Skip to content

Commit

Permalink
feat(rust): unify workers from tcp and smoltcp transport
Browse files Browse the repository at this point in the history
The unified workers were moved to ockam_transport_core
  • Loading branch information
conectado committed Jan 11, 2022
1 parent b449ea2 commit c6b1cd9
Show file tree
Hide file tree
Showing 28 changed files with 768 additions and 538 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 15 additions & 4 deletions implementations/rust/ockam/ockam_transport_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,28 @@ default = ["std"]

# Feature (enabled by default): "std" enables functionality expected to
# be available on a standard platform.
std = ["ockam_core/std"]
std = [
"ockam_core/std",
"ockam_node/std",
"tokio",
"tokio/io-util",
"tokio/net",
]

# Feature: "no_std" enables functionality required for platforms
# without the standard library.
no_std = ["ockam_core/no_std"]
no_std = ["ockam_core/no_std", "ockam_node/no_std"]

# Feature: "alloc" enables support for heap allocation on "no_std"
# platforms, requires nightly.
alloc = ["ockam_core/alloc"]
alloc = ["ockam_core/alloc", "ockam_node/alloc"]

[dependencies]
ockam_core = { path = "../ockam_core", version = "^0.42.1-dev", default_features = false }
tracing = { version = "0.1", default-features = false }
smoltcp = { version = "0.8", default-features = false, features = ["proto-ipv4"] }
smoltcp = { version = "0.8", default-features = false, features = [
"proto-ipv4",
] }
futures = { version = "0.3", default-features = false }
ockam_node = { path = "../ockam_node", version = "0.41.1-dev", default-features = false }
tokio = { version = "1.8", default-features = false, optional = true }
8 changes: 8 additions & 0 deletions implementations/rust/ockam/ockam_transport_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

pub use error::TransportError;

#[cfg(feature = "alloc")]
#[macro_use]
extern crate alloc;

mod error;
#[cfg(test)]
mod error_test;
pub mod utils;
pub mod workers;

/// TCP address type constant
pub const TCP: u8 = 1;
163 changes: 163 additions & 0 deletions implementations/rust/ockam/ockam_transport_core/src/workers/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Implements what we need from the futures::io trait(with no_std)
use core::fmt::Debug;
use core::fmt::Display;
use core::mem;
use core::pin::Pin;
use futures::{ready, Future};
use ockam_core::compat::error::Error;
use ockam_core::compat::io;
use ockam_core::compat::task::{Context, Poll};
use smoltcp;

/// Error type encapsulating either a smoltcp error or the standard error
/// Also generic errors that can come from the implementations of this module.
/// The API will always try to use the custom errors before defaulting to a wrapper from other crate.
// TODO: This probably should be absorbed by TransportError(as per the error guidelines)
#[derive(Debug)]
pub enum IoError {
/// Operation finished due to reaching the end of file prematurely.
UnexpectedEof,
/// Standard error.
StdError(io::Error),
/// Smoltcp error.
SmolError(smoltcp::Error),
}

impl Display for IoError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::UnexpectedEof => write!(f, "Unexpected EOF encoutered"),
Self::StdError(e) => Display::fmt(&e, f),
Self::SmolError(e) => Display::fmt(&e, f),
}
}
}

impl Error for IoError {}

impl From<io::Error> for IoError {
fn from(err: io::Error) -> Self {
match err.kind() {
io::ErrorKind::UnexpectedEof => return Self::UnexpectedEof,
_ => {}
}
IoError::StdError(err)
}
}

impl From<smoltcp::Error> for IoError {
fn from(err: smoltcp::Error) -> Self {
IoError::SmolError(err)
}
}

pub type Result<T> = core::result::Result<T, IoError>;

pub trait AsyncRead {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
-> Poll<Result<usize>>;
}

#[cfg(feature = "std")]
impl<T: tokio::io::AsyncRead> AsyncRead for T {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
let mut buf = tokio::io::ReadBuf::new(buf);
let res = tokio::io::AsyncRead::poll_read(self, cx, &mut buf);
res.map_ok(|()| buf.filled().len())
.map_err(|err| err.into())
}
}

pub trait AsyncWrite {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
}

#[cfg(feature = "std")]
impl<T: tokio::io::AsyncWrite> AsyncWrite for T {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
tokio::io::AsyncWrite::poll_write(self, cx, buf).map_err(|err| err.into())
}
}

pub trait AsyncWriteExt: AsyncWrite {
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> WriteAll<'a, Self> {
WriteAll::new(self, buf)
}
}

pub trait AsyncReadExt: AsyncRead {
fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self> {
ReadExact::new(self, buf)
}
}

impl<T> AsyncWriteExt for T where T: AsyncWrite {}
impl<T> AsyncReadExt for T where T: AsyncRead {}

#[derive(Debug)]
pub struct WriteAll<'a, W: AsyncWrite + ?Sized> {
writer: &'a mut W,
buf: &'a [u8],
}

impl<'a, W: AsyncWrite + ?Sized> WriteAll<'a, W> {
fn new(writer: &'a mut W, buf: &'a [u8]) -> Self {
Self { writer, buf }
}
}

#[derive(Debug)]
pub struct ReadExact<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut [u8],
}

impl<'a, R: AsyncRead + ?Sized> ReadExact<'a, R> {
fn new(reader: &'a mut R, buf: &'a mut [u8]) -> Self {
Self { reader, buf }
}
}

impl<'a, W: AsyncWrite + ?Sized + Unpin> Future for WriteAll<'a, W> {
type Output = Result<()>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut *this.writer).poll_write(cx, this.buf))?;
{
let (_, rest) = mem::replace(&mut this.buf, &[]).split_at(n);
this.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(IoError::UnexpectedEof));
}
}

Poll::Ready(Ok(()))
}
}

impl<'a, R: AsyncRead + ?Sized + Unpin> Future for ReadExact<'a, R> {
type Output = Result<()>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?;

{
let (_, rest) = mem::replace(&mut this.buf, &mut []).split_at_mut(n);
this.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(IoError::UnexpectedEof));
}
}
Poll::Ready(Ok(()))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use crate::workers::traits::NoConnector;
use crate::workers::TcpSendWorker;
use crate::TransportError;
use core::fmt::Display;
use ockam_core::async_trait;
use ockam_core::{Address, Processor, Result};
use ockam_node::Context;
use tracing::{debug, trace};

use super::traits::{IntoSplit, NoHostnames, PairRegister, TcpAccepter, TcpBinder};
use super::{AsyncReadExt, AsyncWriteExt};

pub struct TcpListenProcessor<T, U> {
inner: T,
router_handle: U,
cluster_name: &'static str,
}

impl<T, U, W, X, Y, Z> TcpListenProcessor<T, U>
where
T: TcpAccepter<Stream = Z, Peer = W> + Send + Sync + 'static,
U: PairRegister<NoHostnames, W, &'static str> + Send + Sync + 'static,
Z: IntoSplit<ReadHalf = X, WriteHalf = Y> + Send + Sync + 'static,
X: AsyncReadExt + Send + Unpin + 'static,
Y: AsyncWriteExt + Send + Unpin + 'static,
W: Send + Sync + Clone + Display + 'static,
{
pub async fn start<A, B>(
ctx: &Context,
router_handle: U,
addr: A,
binder: B,
cluster_name: &'static str,
) -> Result<()>
where
B: TcpBinder<A, Listener = T>,
A: Display,
{
let waddr = Address::random(0);

debug!("Binding TcpListener to {}", addr);
let inner = binder.bind(addr).await.map_err(TransportError::from)?;
let worker = Self {
inner,
router_handle,
cluster_name,
};

ctx.start_processor(waddr, worker).await?;
Ok(())
}
}

#[async_trait]
impl<T, U, W, X, Y, Z> Processor for TcpListenProcessor<T, U>
where
T: Send + TcpAccepter<Stream = Z, Peer = W> + Sync + 'static,
U: PairRegister<NoHostnames, W, &'static str> + Send + Sync + 'static,
W: Send + Display + Sync + Clone + 'static,
Z: IntoSplit<ReadHalf = X, WriteHalf = Y> + Send + Sync + 'static,
X: AsyncReadExt + Send + Unpin + 'static,
Y: AsyncWriteExt + Send + Unpin + 'static,
{
type Context = Context;

async fn initialize(&mut self, ctx: &mut Context) -> Result<()> {
ctx.set_cluster(self.cluster_name).await
}

async fn process(&mut self, ctx: &mut Self::Context) -> Result<bool> {
trace!("Waiting for incoming TCP connection...");

// Wait for an incoming connection
let (stream, peer) = self.inner.accept().await.map_err(TransportError::from)?;

// And spawn a connection worker for it
// TODO: Here stream_connector is not really needed, in fact it's never needed when stream is passed
// Reflecting that in the API will ease the use of TcpSendWorker.
let pair = TcpSendWorker::start_pair(
ctx,
Some(stream),
NoConnector(core::marker::PhantomData::<Z>),
peer,
NoHostnames,
self.cluster_name,
)
.await?;

// Register the connection with the local TcpRouter
self.router_handle.register(&pair).await?;

Ok(true)
}
}
11 changes: 11 additions & 0 deletions implementations/rust/ockam/ockam_transport_core/src/workers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
mod io;
mod listener;
mod receiver;
mod sender;

pub mod traits;

pub use io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, IoError, Result};
pub use listener::TcpListenProcessor;
pub use receiver::TcpRecvProcessor;
pub use sender::{TcpSendWorker, WorkerPair};
Loading

0 comments on commit c6b1cd9

Please sign in to comment.