From 1953014b93f77e03b4434c17e657156c9166b2f8 Mon Sep 17 00:00:00 2001 From: clabby Date: Fri, 15 Nov 2024 15:22:08 -0500 Subject: [PATCH] feat(preimage): Decouple from `kona-common` (#817) --- Cargo.lock | 2 - bin/client/src/fault/mod.rs | 7 +- bin/client/src/lib.rs | 3 + .../preimage => bin/client}/src/pipe.rs | 21 ++++ bin/host/src/lib.rs | 3 +- crates/proof-sdk/preimage/Cargo.toml | 6 +- crates/proof-sdk/preimage/src/errors.rs | 19 ++- crates/proof-sdk/preimage/src/hint.rs | 111 ++++++++---------- crates/proof-sdk/preimage/src/lib.rs | 13 +- .../proof-sdk/preimage/src/native_channel.rs | 79 +++++++++++++ crates/proof-sdk/preimage/src/oracle.rs | 95 +++++++-------- crates/proof-sdk/preimage/src/test_utils.rs | 29 ----- crates/proof-sdk/preimage/src/traits.rs | 39 +++++- 13 files changed, 266 insertions(+), 161 deletions(-) rename {crates/proof-sdk/preimage => bin/client}/src/pipe.rs (87%) create mode 100644 crates/proof-sdk/preimage/src/native_channel.rs delete mode 100644 crates/proof-sdk/preimage/src/test_utils.rs diff --git a/Cargo.lock b/Cargo.lock index ca2c8e6fd..e8112fb7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2474,8 +2474,6 @@ version = "0.0.4" dependencies = [ "alloy-primitives", "async-trait", - "kona-common", - "os_pipe", "rkyv", "serde", "thiserror 2.0.3", diff --git a/bin/client/src/fault/mod.rs b/bin/client/src/fault/mod.rs index 3f55e1338..8773f68d6 100644 --- a/bin/client/src/fault/mod.rs +++ b/bin/client/src/fault/mod.rs @@ -1,7 +1,8 @@ //! Contains FPVM-specific constructs for the `kona-client` program. +use kona_client::PipeHandle; use kona_common::FileDescriptor; -use kona_preimage::{HintWriter, OracleReader, PipeHandle}; +use kona_preimage::{HintWriter, OracleReader}; mod handler; pub(crate) use handler::fpvm_handle_register; @@ -15,7 +16,7 @@ static HINT_WRITER_PIPE: PipeHandle = PipeHandle::new(FileDescriptor::HintRead, FileDescriptor::HintWrite); /// The global preimage oracle reader. -pub(crate) static ORACLE_READER: OracleReader = OracleReader::new(ORACLE_READER_PIPE); +pub(crate) static ORACLE_READER: OracleReader = OracleReader::new(ORACLE_READER_PIPE); /// The global hint writer. -pub(crate) static HINT_WRITER: HintWriter = HintWriter::new(HINT_WRITER_PIPE); +pub(crate) static HINT_WRITER: HintWriter = HintWriter::new(HINT_WRITER_PIPE); diff --git a/bin/client/src/lib.rs b/bin/client/src/lib.rs index 8575ce8f4..36c706377 100644 --- a/bin/client/src/lib.rs +++ b/bin/client/src/lib.rs @@ -22,5 +22,8 @@ pub use hint::HintType; pub mod boot; pub use boot::BootInfo; +mod pipe; +pub use pipe::PipeHandle; + mod caching_oracle; pub use caching_oracle::{CachingOracle, FlushableCache}; diff --git a/crates/proof-sdk/preimage/src/pipe.rs b/bin/client/src/pipe.rs similarity index 87% rename from crates/proof-sdk/preimage/src/pipe.rs rename to bin/client/src/pipe.rs index 6f6cf878d..6a50f6ee4 100644 --- a/crates/proof-sdk/preimage/src/pipe.rs +++ b/bin/client/src/pipe.rs @@ -1,6 +1,8 @@ //! This module contains a rudamentary pipe between two file descriptors, using [kona_common::io] //! for reading and writing from the file descriptors. +use alloc::boxed::Box; +use async_trait::async_trait; use core::{ cell::RefCell, cmp::Ordering, @@ -9,6 +11,10 @@ use core::{ task::{Context, Poll}, }; use kona_common::{errors::IOResult, io, FileDescriptor}; +use kona_preimage::{ + errors::{ChannelError, ChannelResult}, + Channel, +}; /// [PipeHandle] is a handle for one end of a bidirectional pipe. #[derive(Debug, Clone, Copy)] @@ -51,6 +57,21 @@ impl PipeHandle { } } +#[async_trait] +impl Channel for PipeHandle { + async fn read(&self, buf: &mut [u8]) -> ChannelResult { + self.read(buf).map_err(|_| ChannelError::Closed) + } + + async fn read_exact(&self, buf: &mut [u8]) -> ChannelResult { + self.read_exact(buf).await.map_err(|_| ChannelError::Closed) + } + + async fn write(&self, buf: &[u8]) -> ChannelResult { + self.write(buf).await.map_err(|_| ChannelError::Closed) + } +} + /// A future that reads from a pipe, returning [Poll::Ready] when the buffer is full. struct ReadFuture<'a> { /// The pipe handle to read from diff --git a/bin/host/src/lib.rs b/bin/host/src/lib.rs index fc474531a..fc7b882aa 100644 --- a/bin/host/src/lib.rs +++ b/bin/host/src/lib.rs @@ -17,8 +17,9 @@ use server::PreimageServer; use anyhow::{anyhow, bail, Result}; use command_fds::{CommandFdExt, FdMapping}; use futures::FutureExt; +use kona_client::PipeHandle; use kona_common::FileDescriptor; -use kona_preimage::{HintReader, OracleServer, PipeHandle}; +use kona_preimage::{HintReader, OracleServer}; use kv::KeyValueStore; use std::{ io::{stderr, stdin, stdout}, diff --git a/crates/proof-sdk/preimage/Cargo.toml b/crates/proof-sdk/preimage/Cargo.toml index ed8dd107e..4f41cf1e5 100644 --- a/crates/proof-sdk/preimage/Cargo.toml +++ b/crates/proof-sdk/preimage/Cargo.toml @@ -18,8 +18,8 @@ thiserror.workspace = true async-trait.workspace = true alloy-primitives.workspace = true -# Workspace -kona-common.workspace = true +# `std` feature dependencies +tokio = { workspace = true, features = ["full"], optional = true } # `rkyv` feature dependencies rkyv = { workspace = true, optional = true } @@ -28,10 +28,10 @@ rkyv = { workspace = true, optional = true } serde = { workspace = true, optional = true, features = ["derive"] } [dev-dependencies] -os_pipe.workspace = true tokio = { workspace = true, features = ["full"] } [features] default = [] +std = ["dep:tokio"] rkyv = ["dep:rkyv"] serde = ["dep:serde"] diff --git a/crates/proof-sdk/preimage/src/errors.rs b/crates/proof-sdk/preimage/src/errors.rs index e52df1738..d5dabd01c 100644 --- a/crates/proof-sdk/preimage/src/errors.rs +++ b/crates/proof-sdk/preimage/src/errors.rs @@ -1,7 +1,6 @@ //! Errors for the `kona-preimage` crate. use alloc::string::String; -use kona_common::errors::IOError; use thiserror::Error; /// A [PreimageOracleError] is an enum that differentiates pipe-related errors from other errors @@ -13,7 +12,7 @@ use thiserror::Error; pub enum PreimageOracleError { /// The pipe has been broken. #[error(transparent)] - IOError(#[from] IOError), + IOError(#[from] ChannelError), /// The preimage key is invalid. #[error("Invalid preimage key.")] InvalidPreimageKey, @@ -30,3 +29,19 @@ pub enum PreimageOracleError { /// A [Result] type for the [PreimageOracleError] enum. pub type PreimageOracleResult = Result; + +/// A [ChannelError] is an enum that describes the error cases of a [Channel] trait implementation. +/// +/// [Channel]: crate::Channel +#[derive(Error, Debug)] +pub enum ChannelError { + /// The channel is closed. + #[error("Channel is closed.")] + Closed, + /// Unexpected EOF. + #[error("Unexpected EOF in channel read operation.")] + UnexpectedEOF, +} + +/// A [Result] type for the [ChannelError] enum. +pub type ChannelResult = Result; diff --git a/crates/proof-sdk/preimage/src/hint.rs b/crates/proof-sdk/preimage/src/hint.rs index 9dcb16061..1424bcedf 100644 --- a/crates/proof-sdk/preimage/src/hint.rs +++ b/crates/proof-sdk/preimage/src/hint.rs @@ -1,47 +1,46 @@ use crate::{ errors::{PreimageOracleError, PreimageOracleResult}, traits::{HintRouter, HintWriterClient}, - HintReaderServer, PipeHandle, + Channel, HintReaderServer, }; use alloc::{boxed::Box, format, string::String, vec}; use async_trait::async_trait; use tracing::{error, trace}; -/// A [HintWriter] is a high-level interface to the hint pipe. It provides a way to write hints to -/// the host. +/// A [HintWriter] is a high-level interface to the hint channel. It provides a way to write hints +/// to the host. #[derive(Debug, Clone, Copy)] -pub struct HintWriter { - pipe_handle: PipeHandle, +pub struct HintWriter { + channel: C, } -impl HintWriter { - /// Create a new [HintWriter] from a [PipeHandle]. - pub const fn new(pipe_handle: PipeHandle) -> Self { - Self { pipe_handle } +impl HintWriter { + /// Create a new [HintWriter] from a [Channel]. + pub const fn new(channel: C) -> Self { + Self { channel } } } #[async_trait] -impl HintWriterClient for HintWriter { - /// Write a hint to the host. This will overwrite any existing hint in the pipe, and block until - /// all data has been written. +impl HintWriterClient for HintWriter +where + C: Channel + Send + Sync, +{ + /// Write a hint to the host. This will overwrite any existing hint in the channel, and block + /// until all data has been written. async fn write(&self, hint: &str) -> PreimageOracleResult<()> { - // Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix - // followed by the hint string. - let mut hint_bytes = vec![0u8; hint.len() + 4]; - hint_bytes[0..4].copy_from_slice(u32::to_be_bytes(hint.len() as u32).as_ref()); - hint_bytes[4..].copy_from_slice(hint.as_bytes()); - trace!(target: "hint_writer", "Writing hint \"{hint}\""); - // Write the hint to the host. - self.pipe_handle.write(&hint_bytes).await?; + // Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix + // followed by the hint string. + self.channel.write(u32::to_be_bytes(hint.len() as u32).as_ref()).await?; + self.channel.write(hint.as_bytes()).await?; trace!(target: "hint_writer", "Successfully wrote hint"); // Read the hint acknowledgement from the host. let mut hint_ack = [0u8; 1]; - self.pipe_handle.read_exact(&mut hint_ack).await?; + self.channel.read_exact(&mut hint_ack).await?; trace!(target: "hint_writer", "Received hint acknowledgement"); @@ -52,36 +51,42 @@ impl HintWriterClient for HintWriter { /// A [HintReader] is a router for hints sent by the [HintWriter] from the client program. It /// provides a way for the host to prepare preimages for reading. #[derive(Debug, Clone, Copy)] -pub struct HintReader { - pipe_handle: PipeHandle, +pub struct HintReader { + channel: C, } -impl HintReader { - /// Create a new [HintReader] from a [PipeHandle]. - pub const fn new(pipe_handle: PipeHandle) -> Self { - Self { pipe_handle } +impl HintReader +where + C: Channel, +{ + /// Create a new [HintReader] from a [Channel]. + pub const fn new(channel: C) -> Self { + Self { channel } } } #[async_trait] -impl HintReaderServer for HintReader { +impl HintReaderServer for HintReader +where + C: Channel + Send + Sync, +{ async fn next_hint(&self, hint_router: &R) -> PreimageOracleResult<()> where R: HintRouter + Send + Sync, { // Read the length of the raw hint payload. let mut len_buf = [0u8; 4]; - self.pipe_handle.read_exact(&mut len_buf).await?; + self.channel.read_exact(&mut len_buf).await?; let len = u32::from_be_bytes(len_buf); // Read the raw hint payload. let mut raw_payload = vec![0u8; len as usize]; - self.pipe_handle.read_exact(raw_payload.as_mut_slice()).await?; + self.channel.read_exact(raw_payload.as_mut_slice()).await?; let payload = match String::from_utf8(raw_payload) { Ok(p) => p, Err(e) => { // Write back on error to prevent blocking the client. - self.pipe_handle.write(&[0x00]).await?; + self.channel.write(&[0x00]).await?; return Err(PreimageOracleError::Other(format!( "Failed to decode hint payload: {e}" @@ -94,14 +99,14 @@ impl HintReaderServer for HintReader { // Route the hint if let Err(e) = hint_router.route_hint(payload).await { // Write back on error to prevent blocking the client. - self.pipe_handle.write(&[0x00]).await?; + self.channel.write(&[0x00]).await?; error!("Failed to route hint: {e}"); return Err(e); } // Write back an acknowledgement to the client to unblock their process. - self.pipe_handle.write(&[0x00]).await?; + self.channel.write(&[0x00]).await?; trace!(target: "hint_reader", "Successfully routed and acknowledged hint"); @@ -112,10 +117,8 @@ impl HintReaderServer for HintReader { #[cfg(test)] mod test { use super::*; - use crate::test_utils::bidirectional_pipe; + use crate::native_channel::BidirectionalChannel; use alloc::{sync::Arc, vec::Vec}; - use kona_common::FileDescriptor; - use std::os::unix::io::AsRawFd; use tokio::sync::Mutex; struct TestRouter { @@ -143,13 +146,10 @@ mod test { async fn test_unblock_on_bad_utf8() { let mock_data = [0xf0, 0x90, 0x28, 0xbc]; - let hint_pipe = bidirectional_pipe().unwrap(); + let hint_channel = BidirectionalChannel::new::<2>().unwrap(); let client = tokio::task::spawn(async move { - let hint_writer = HintWriter::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize), - )); + let hint_writer = HintWriter::new(hint_channel.client); #[allow(invalid_from_utf8_unchecked)] hint_writer.write(unsafe { alloc::str::from_utf8_unchecked(&mock_data) }).await @@ -157,10 +157,7 @@ mod test { let host = tokio::task::spawn(async move { let router = TestRouter { incoming_hints: Default::default() }; - let hint_reader = HintReader::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize), - )); + let hint_reader = HintReader::new(hint_channel.host); hint_reader.next_hint(&router).await }); @@ -178,21 +175,15 @@ mod test { async fn test_unblock_on_fetch_failure() { const MOCK_DATA: &str = "test-hint 0xfacade"; - let hint_pipe = bidirectional_pipe().unwrap(); + let hint_channel = BidirectionalChannel::new::<2>().unwrap(); let client = tokio::task::spawn(async move { - let hint_writer = HintWriter::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize), - )); + let hint_writer = HintWriter::new(hint_channel.client); hint_writer.write(MOCK_DATA).await }); let host = tokio::task::spawn(async move { - let hint_reader = HintReader::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize), - )); + let hint_reader = HintReader::new(hint_channel.host); hint_reader.next_hint(&TestFailRouter).await }); @@ -206,13 +197,10 @@ mod test { const MOCK_DATA: &str = "test-hint 0xfacade"; let incoming_hints = Arc::new(Mutex::new(Vec::new())); - let hint_pipe = bidirectional_pipe().unwrap(); + let hint_channel = BidirectionalChannel::new::<2>().unwrap(); let client = tokio::task::spawn(async move { - let hint_writer = HintWriter::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize), - )); + let hint_writer = HintWriter::new(hint_channel.client); hint_writer.write(MOCK_DATA).await }); @@ -221,10 +209,7 @@ mod test { async move { let router = TestRouter { incoming_hints: incoming_hints_ref }; - let hint_reader = HintReader::new(PipeHandle::new( - FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize), - )); + let hint_reader = HintReader::new(hint_channel.host); hint_reader.next_hint(&router).await.unwrap(); } }); diff --git a/crates/proof-sdk/preimage/src/lib.rs b/crates/proof-sdk/preimage/src/lib.rs index 94e411be7..f38b9fca3 100644 --- a/crates/proof-sdk/preimage/src/lib.rs +++ b/crates/proof-sdk/preimage/src/lib.rs @@ -5,7 +5,7 @@ )] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(not(test), warn(unused_crate_dependencies))] -#![cfg_attr(not(test), no_std)] +#![cfg_attr(not(any(test, feature = "std")), no_std)] extern crate alloc; @@ -20,14 +20,13 @@ pub use oracle::{OracleReader, OracleServer}; mod hint; pub use hint::{HintReader, HintWriter}; -mod pipe; -pub use pipe::PipeHandle; - mod traits; pub use traits::{ - CommsClient, HintReaderServer, HintRouter, HintWriterClient, PreimageFetcher, + Channel, CommsClient, HintReaderServer, HintRouter, HintWriterClient, PreimageFetcher, PreimageOracleClient, PreimageOracleServer, }; -#[cfg(test)] -mod test_utils; +#[cfg(any(test, feature = "std"))] +mod native_channel; +#[cfg(any(test, feature = "std"))] +pub use native_channel::{BidirectionalChannel, NativeChannel}; diff --git a/crates/proof-sdk/preimage/src/native_channel.rs b/crates/proof-sdk/preimage/src/native_channel.rs new file mode 100644 index 000000000..0140f54cd --- /dev/null +++ b/crates/proof-sdk/preimage/src/native_channel.rs @@ -0,0 +1,79 @@ +//! Native implementation of the [Channel] trait, backed by [tokio]'s [mpsc] channel primitives. +//! +//! [mpsc]: tokio::sync::mpsc + +use crate::{ + errors::{ChannelError, ChannelResult}, + Channel, +}; +use async_trait::async_trait; +use std::{io::Result, sync::Arc}; +use tokio::sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, +}; + +/// A bidirectional channel, allowing for synchronized communication between two parties. +#[derive(Debug)] +pub struct BidirectionalChannel { + /// The client handle of the channel. + pub client: NativeChannel, + /// The host handle of the channel. + pub host: NativeChannel, +} + +impl BidirectionalChannel { + /// Creates a [BidirectionalChannel] instance. + pub fn new() -> Result { + let (bw, ar) = channel(BUF); + let (aw, br) = channel(BUF); + + Ok(Self { + client: NativeChannel { read: Arc::new(Mutex::new(ar)), write: aw }, + host: NativeChannel { read: Arc::new(Mutex::new(br)), write: bw }, + }) + } +} + +/// A channel with a receiver and sender. +#[derive(Debug)] +pub struct NativeChannel { + /// The receiver of the channel. + pub(crate) read: Arc>>>, + /// The sender of the channel. + pub(crate) write: Sender>, +} + +#[async_trait] +impl Channel for NativeChannel { + async fn read(&self, buf: &mut [u8]) -> ChannelResult { + let data = self.read.lock().await.recv().await.ok_or(ChannelError::Closed)?; + let len = data.len().min(buf.len()); + buf[..len].copy_from_slice(&data[..len]); + Ok(len) + } + + async fn read_exact(&self, buf: &mut [u8]) -> ChannelResult { + let mut read_lock = self.read.lock().await; + + let mut read = 0; + while read < buf.len() { + let data = read_lock.recv().await.ok_or(ChannelError::Closed)?; + let len = data.len(); + + if len + read > buf.len() { + return Err(ChannelError::UnexpectedEOF); + } + + buf[read..read + len].copy_from_slice(&data[..]); + read += len; + } + + Ok(read) + } + + async fn write(&self, buf: &[u8]) -> ChannelResult { + self.write.send(buf.to_vec()).await.unwrap(); + Ok(buf.len()) + } +} diff --git a/crates/proof-sdk/preimage/src/oracle.rs b/crates/proof-sdk/preimage/src/oracle.rs index 6ea79a214..eb1ac809b 100644 --- a/crates/proof-sdk/preimage/src/oracle.rs +++ b/crates/proof-sdk/preimage/src/oracle.rs @@ -1,21 +1,24 @@ use crate::{ errors::{PreimageOracleError, PreimageOracleResult}, - traits::PreimageFetcher, - PipeHandle, PreimageKey, PreimageOracleClient, PreimageOracleServer, + traits::{Channel, PreimageFetcher}, + PreimageKey, PreimageOracleClient, PreimageOracleServer, }; use alloc::{boxed::Box, vec::Vec}; use tracing::trace; -/// An [OracleReader] is a high-level interface to the preimage oracle. +/// An [OracleReader] is a high-level interface to the preimage oracle channel. #[derive(Debug, Clone, Copy)] -pub struct OracleReader { - pipe_handle: PipeHandle, +pub struct OracleReader { + channel: C, } -impl OracleReader { - /// Create a new [OracleReader] from a [PipeHandle]. - pub const fn new(pipe_handle: PipeHandle) -> Self { - Self { pipe_handle } +impl OracleReader +where + C: Channel, +{ + /// Create a new [OracleReader] from a [Channel]. + pub const fn new(channel: C) -> Self { + Self { channel } } /// Set the preimage key for the global oracle reader. This will overwrite any existing key, and @@ -24,17 +27,20 @@ impl OracleReader { async fn write_key(&self, key: PreimageKey) -> PreimageOracleResult { // Write the key to the host so that it can prepare the preimage. let key_bytes: [u8; 32] = key.into(); - self.pipe_handle.write(&key_bytes).await?; + self.channel.write(&key_bytes).await?; // Read the length prefix and reset the cursor. let mut length_buffer = [0u8; 8]; - self.pipe_handle.read_exact(&mut length_buffer).await?; + self.channel.read_exact(&mut length_buffer).await?; Ok(u64::from_be_bytes(length_buffer) as usize) } } #[async_trait::async_trait] -impl PreimageOracleClient for OracleReader { +impl PreimageOracleClient for OracleReader +where + C: Channel + Send + Sync, +{ /// Get the data corresponding to the currently set key from the host. Return the data in a new /// heap allocated `Vec` async fn get(&self, key: PreimageKey) -> PreimageOracleResult> { @@ -50,8 +56,8 @@ impl PreimageOracleClient for OracleReader { trace!(target: "oracle_client", "Reading data from preimage oracle. Key {key}"); - // Grab a read lock on the preimage pipe to read the data. - self.pipe_handle.read_exact(&mut data_buffer).await?; + // Grab a read lock on the preimage channel to read the data. + self.channel.read_exact(&mut data_buffer).await?; trace!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}"); @@ -77,7 +83,7 @@ impl PreimageOracleClient for OracleReader { return Ok(()); } - self.pipe_handle.read_exact(buf).await?; + self.channel.read_exact(buf).await?; trace!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}"); @@ -87,26 +93,32 @@ impl PreimageOracleClient for OracleReader { /// An [OracleServer] is a router for the host to serve data back to the client [OracleReader]. #[derive(Debug, Clone, Copy)] -pub struct OracleServer { - pipe_handle: PipeHandle, +pub struct OracleServer { + channel: C, } -impl OracleServer { - /// Create a new [OracleServer] from a [PipeHandle]. - pub const fn new(pipe_handle: PipeHandle) -> Self { - Self { pipe_handle } +impl OracleServer +where + C: Channel, +{ + /// Create a new [OracleServer] from a [Channel]. + pub const fn new(chanel: C) -> Self { + Self { channel: chanel } } } #[async_trait::async_trait] -impl PreimageOracleServer for OracleServer { +impl PreimageOracleServer for OracleServer +where + C: Channel + Send + Sync, +{ async fn next_preimage_request(&self, fetcher: &F) -> Result<(), PreimageOracleError> where F: PreimageFetcher + Send + Sync, { // Read the preimage request from the client, and throw early if there isn't is any. let mut buf = [0u8; 32]; - self.pipe_handle.read_exact(&mut buf).await?; + self.channel.read_exact(&mut buf).await?; let preimage_key = PreimageKey::try_from(buf)?; trace!(target: "oracle_server", "Fetching preimage for key {preimage_key}"); @@ -115,12 +127,8 @@ impl PreimageOracleServer for OracleServer { let value = fetcher.get_preimage(preimage_key).await?; // Write the length as a big-endian u64 followed by the data. - let data = [(value.len() as u64).to_be_bytes().as_ref(), value.as_ref()] - .into_iter() - .flatten() - .copied() - .collect::>(); - self.pipe_handle.write(data.as_slice()).await?; + self.channel.write(value.len().to_be_bytes().as_ref()).await?; + self.channel.write(value.as_ref()).await?; trace!(target: "oracle_server", "Successfully wrote preimage data for key {preimage_key}"); @@ -131,11 +139,10 @@ impl PreimageOracleServer for OracleServer { #[cfg(test)] mod test { use super::*; - use crate::{test_utils::bidirectional_pipe, PreimageKeyType}; + use crate::{native_channel::BidirectionalChannel, PreimageKeyType}; use alloc::sync::Arc; use alloy_primitives::keccak256; - use kona_common::FileDescriptor; - use std::{collections::HashMap, os::unix::io::AsRawFd}; + use std::collections::HashMap; use tokio::sync::Mutex; struct TestFetcher { @@ -166,13 +173,10 @@ mod test { Arc::new(Mutex::new(preimages)) }; - let preimage_pipe = bidirectional_pipe().unwrap(); + let preimage_channel = BidirectionalChannel::new::<2>().unwrap(); let client = tokio::task::spawn(async move { - let oracle_reader = OracleReader::new(PipeHandle::new( - FileDescriptor::Wildcard(preimage_pipe.client.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(preimage_pipe.client.write.as_raw_fd() as usize), - )); + let oracle_reader = OracleReader::new(preimage_channel.client); let mut contents_a = [0u8; 10]; let mut contents_b = [0u8; 6]; oracle_reader.get_exact(key_a, &mut contents_a).await.unwrap(); @@ -181,10 +185,7 @@ mod test { (contents_a, contents_b) }); tokio::task::spawn(async move { - let oracle_server = OracleServer::new(PipeHandle::new( - FileDescriptor::Wildcard(preimage_pipe.host.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(preimage_pipe.host.write.as_raw_fd() as usize), - )); + let oracle_server = OracleServer::new(preimage_channel.host); let test_fetcher = TestFetcher { preimages: Arc::clone(&preimages) }; loop { @@ -218,23 +219,17 @@ mod test { Arc::new(Mutex::new(preimages)) }; - let preimage_pipe = bidirectional_pipe().unwrap(); + let preimage_channel = BidirectionalChannel::new::<2>().unwrap(); let client = tokio::task::spawn(async move { - let oracle_reader = OracleReader::new(PipeHandle::new( - FileDescriptor::Wildcard(preimage_pipe.client.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(preimage_pipe.client.write.as_raw_fd() as usize), - )); + let oracle_reader = OracleReader::new(preimage_channel.client); let contents_a = oracle_reader.get(key_a).await.unwrap(); let contents_b = oracle_reader.get(key_b).await.unwrap(); (contents_a, contents_b) }); tokio::task::spawn(async move { - let oracle_server = OracleServer::new(PipeHandle::new( - FileDescriptor::Wildcard(preimage_pipe.host.read.as_raw_fd() as usize), - FileDescriptor::Wildcard(preimage_pipe.host.write.as_raw_fd() as usize), - )); + let oracle_server = OracleServer::new(preimage_channel.host); let test_fetcher = TestFetcher { preimages: Arc::clone(&preimages) }; loop { diff --git a/crates/proof-sdk/preimage/src/test_utils.rs b/crates/proof-sdk/preimage/src/test_utils.rs deleted file mode 100644 index 3a12d7c0c..000000000 --- a/crates/proof-sdk/preimage/src/test_utils.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Test utilities for the `kona-preimage` crate. - -use os_pipe::{PipeReader, PipeWriter}; -use std::io::Result; - -/// A bidirectional pipe, with a client and host end. -#[derive(Debug)] -pub(crate) struct BidirectionalPipe { - pub(crate) client: Pipe, - pub(crate) host: Pipe, -} - -/// A single-direction pipe, with a read and write end. -#[derive(Debug)] -pub(crate) struct Pipe { - pub(crate) read: PipeReader, - pub(crate) write: PipeWriter, -} - -/// Creates a [BidirectionalPipe] instance. -pub(crate) fn bidirectional_pipe() -> Result { - let (ar, bw) = os_pipe::pipe()?; - let (br, aw) = os_pipe::pipe()?; - - Ok(BidirectionalPipe { - client: Pipe { read: ar, write: aw }, - host: Pipe { read: br, write: bw }, - }) -} diff --git a/crates/proof-sdk/preimage/src/traits.rs b/crates/proof-sdk/preimage/src/traits.rs index f7d09db38..5614b2125 100644 --- a/crates/proof-sdk/preimage/src/traits.rs +++ b/crates/proof-sdk/preimage/src/traits.rs @@ -1,4 +1,7 @@ -use crate::{errors::PreimageOracleResult, PreimageKey}; +use crate::{ + errors::{ChannelResult, PreimageOracleResult}, + PreimageKey, +}; use alloc::{boxed::Box, string::String, vec::Vec}; use async_trait::async_trait; @@ -98,3 +101,37 @@ pub trait PreimageFetcher { /// - `Err(_)` if the preimage could not be fetched. async fn get_preimage(&self, key: PreimageKey) -> PreimageOracleResult>; } + +/// A [Channel] is a high-level interface to read and write data to a counterparty. +#[async_trait] +pub trait Channel { + /// Asynchronously read data from the channel into the provided buffer. + /// + /// # Arguments + /// - `buf`: The buffer to read data into. + /// + /// # Returns + /// - `Ok(usize)`: The number of bytes read. + /// - `Err(_)` if the data could not be read. + async fn read(&self, buf: &mut [u8]) -> ChannelResult; + + /// Asynchronously read exactly `buf.len()` bytes into `buf` from the channel. + /// + /// # Arguments + /// - `buf`: The buffer to read data into. + /// + /// # Returns + /// - `Ok(())` if the data was successfully read. + /// - `Err(_)` if the data could not be read. + async fn read_exact(&self, buf: &mut [u8]) -> ChannelResult; + + /// Asynchronously write the provided buffer to the channel. + /// + /// # Arguments + /// - `buf`: The buffer to write to the host. + /// + /// # Returns + /// - `Ok(usize)`: The number of bytes written. + /// - `Err(_)` if the data could not be written. + async fn write(&self, buf: &[u8]) -> ChannelResult; +}