Skip to content
This repository has been archived by the owner on Aug 23, 2022. It is now read-only.

poll_data_channel: implement polling instead of proxying to stream #9

Merged
merged 1 commit into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 191 additions & 22 deletions src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use util::marshal::*;
use bytes::{Buf, Bytes};
use derive_builder::Builder;
use std::fmt;
use std::future::Future;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
Expand Down Expand Up @@ -340,23 +341,68 @@ impl DataChannel {
}
}

/// Default capacity of the temporary read buffer used by [`PollStream`].
const DEFAULT_READ_BUF_SIZE: usize = 8192;

/// State of the read `Future` in [`PollStream`].
enum ReadFut {
/// Nothing in progress.
Idle,
/// Reading data from the underlying stream.
Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
/// Finished reading, but there's unread data in the temporary buffer.
RemainingData(Vec<u8>),
}

impl ReadFut {
/// Gets a mutable reference to the future stored inside `Reading(future)`.
///
/// # Panics
///
/// Panics if `ReadFut` variant is not `Reading`.
fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
match self {
ReadFut::Reading(ref mut fut) => fut,
_ => panic!("expected ReadFut to be Reading"),
}
}
}

/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
/// [`AsyncWrite`].
///
/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
/// additional overhead.
pub struct PollDataChannel {
data_channel: Arc<DataChannel>,
poll_stream: PollStream,

read_fut: ReadFut,
write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,

read_buf_cap: usize,
}

impl PollDataChannel {
/// Constructs a new `PollDataChannel`.
///
/// # Examples
///
/// ```
/// use webrtc_data::data_channel::{DataChannel, PollDataChannel, Config};
/// use sctp::stream::Stream;
/// use std::sync::Arc;
///
/// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default()));
/// let poll_dc = PollDataChannel::new(dc);
/// ```
pub fn new(data_channel: Arc<DataChannel>) -> Self {
let stream = data_channel.stream.clone();
Self {
data_channel,
poll_stream: PollStream::new(stream),
read_fut: ReadFut::Idle,
write_fut: None,
shutdown_fut: None,
read_buf_cap: DEFAULT_READ_BUF_SIZE,
}
}

Expand All @@ -372,44 +418,44 @@ impl PollDataChannel {

/// MessagesSent returns the number of messages sent
pub fn messages_sent(&self) -> usize {
self.data_channel.messages_sent.load(Ordering::SeqCst)
self.data_channel.messages_sent()
}

/// MessagesReceived returns the number of messages received
pub fn messages_received(&self) -> usize {
self.data_channel.messages_received.load(Ordering::SeqCst)
self.data_channel.messages_received()
}

/// BytesSent returns the number of bytes sent
pub fn bytes_sent(&self) -> usize {
self.data_channel.bytes_sent.load(Ordering::SeqCst)
self.data_channel.bytes_sent()
}

/// BytesReceived returns the number of bytes received
pub fn bytes_received(&self) -> usize {
self.data_channel.bytes_received.load(Ordering::SeqCst)
self.data_channel.bytes_received()
}

/// StreamIdentifier returns the Stream identifier associated to the stream.
pub fn stream_identifier(&self) -> u16 {
self.poll_stream.stream_identifier()
self.data_channel.stream_identifier()
}

/// BufferedAmount returns the number of bytes of data currently queued to be
/// sent over this stream.
pub fn buffered_amount(&self) -> usize {
self.poll_stream.buffered_amount()
self.data_channel.buffered_amount()
}

/// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
/// data that is considered "low." Defaults to 0.
pub fn buffered_amount_low_threshold(&self) -> usize {
self.poll_stream.buffered_amount_low_threshold()
self.data_channel.buffered_amount_low_threshold()
}

/// Set the capacity of the temporary read buffer (default: 8192).
pub fn set_read_buf_capacity(&mut self, capacity: usize) {
self.poll_stream.set_read_buf_capacity(capacity)
self.read_buf_cap = capacity
}
}

Expand All @@ -419,7 +465,68 @@ impl AsyncRead for PollDataChannel {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_read(cx, buf)
if buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}

let fut = match self.read_fut {
ReadFut::Idle => {
// read into a temporary buffer because `buf` has an unonymous lifetime, which can
// be shorter than the lifetime of `read_fut`.
let data_channel = self.data_channel.clone();
let mut temp_buf = vec![0; self.read_buf_cap];
self.read_fut = ReadFut::Reading(Box::pin(async move {
data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
temp_buf.truncate(n);
temp_buf
})
}));
self.read_fut.get_reading_mut()
}
ReadFut::Reading(ref mut fut) => fut,
ReadFut::RemainingData(ref mut data) => {
let remaining = buf.remaining();
let len = std::cmp::min(data.len(), remaining);
buf.put_slice(&data[..len]);
if data.len() > remaining {
// ReadFut remains to be RemainingData
data.drain(..len);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
};

loop {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
// retry immediately upon empty data or incomplete chunks
// since there's no way to setup a waker.
Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
// EOF has been reached => don't touch buf and just return Ok
Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Ok(()));
}
Poll::Ready(Err(e)) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Err(e.into()));
}
Poll::Ready(Ok(mut temp_buf)) => {
let remaining = buf.remaining();
let len = std::cmp::min(temp_buf.len(), remaining);
buf.put_slice(&temp_buf[..len]);
if temp_buf.len() > remaining {
temp_buf.drain(..len);
self.read_fut = ReadFut::RemainingData(temp_buf);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
}
}
}
}

Expand All @@ -429,23 +536,84 @@ impl AsyncWrite for PollDataChannel {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.poll_stream).poll_write(cx, buf)
if buf.len() == 0 {
return Poll::Ready(Ok(0));
}

let (fut, fut_is_new) = match self.write_fut.as_mut() {
Some(fut) => (fut, false),
None => {
let data_channel = self.data_channel.clone();
let bytes = Bytes::copy_from_slice(buf);
(
self.write_fut
.get_or_insert(Box::pin(async move { data_channel.write(&bytes).await })),
true,
)
}
};

match fut.as_mut().poll(cx) {
Poll::Pending => {
// If it's the first time we're polling the future, `Poll::Pending` can't be
// returned because that would mean the `PollStream` is not ready for writing. And
// this is not true since we've just created a future, which is going to write the
// buf to the underlying stream.
//
// It's okay to return `Poll::Ready` if the data is buffered (this is what the
// buffered writer and `File` do).
if fut_is_new {
Poll::Ready(Ok(buf.len()))
} else {
// If it's the subsequent poll, it's okay to return `Poll::Pending` as it
// indicates that the `PollStream` is not ready for writing. Only one future
// can be in progress at the time.
Poll::Pending
}
}
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(n)) => {
self.write_fut = None;
Poll::Ready(Ok(n))
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_flush(cx)
match self.write_fut.as_mut() {
Some(fut) => match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(_)) => {
self.write_fut = None;
Poll::Ready(Ok(()))
}
},
None => Poll::Ready(Ok(())),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_shutdown(cx)
}
let fut = match self.shutdown_fut.as_mut() {
Some(fut) => fut,
None => {
let data_channel = self.data_channel.clone();
self.shutdown_fut
.get_or_insert(Box::pin(async move { data_channel.close().await }))
}
};

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.poll_stream).poll_write_vectored(cx, bufs)
match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
}
}
}

Expand All @@ -459,6 +627,7 @@ impl fmt::Debug for PollDataChannel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollDataChannel")
.field("data_channel", &self.data_channel)
.field("read_buf_cap", &self.read_buf_cap)
.finish()
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;

Expand Down Expand Up @@ -37,6 +38,20 @@ impl From<Error> for util::Error {
}
}

impl From<Error> for io::Error {
fn from(error: Error) -> Self {
match error {
e @ Error::Sctp(sctp::Error::ErrEof) => {
io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string())
}
e @ Error::ErrStreamClosed => {
io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
}
e => io::Error::new(io::ErrorKind::Other, e.to_string()),
}
}
}

impl PartialEq<util::Error> for Error {
fn eq(&self, other: &util::Error) -> bool {
if let Some(down) = other.downcast_ref::<Error>() {
Expand Down