diff --git a/Cargo.lock b/Cargo.lock index 46de1e5f0..ee7bcdd57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1694,6 +1694,8 @@ dependencies = [ "serde", "serde_json", "spin 0.9.8", + "tracing", + "tracing-subscriber", ] [[package]] @@ -1702,6 +1704,7 @@ version = "0.0.1" dependencies = [ "anyhow", "cfg-if", + "lazy_static", "linked_list_allocator", ] diff --git a/bin/host/src/fetcher/mod.rs b/bin/host/src/fetcher/mod.rs index 2587f984c..d43ef9079 100644 --- a/bin/host/src/fetcher/mod.rs +++ b/bin/host/src/fetcher/mod.rs @@ -349,14 +349,15 @@ where }; let mut kv_write_lock = self.kv_store.write().await; - kv_write_lock.set(hash, code.into()); + kv_write_lock + .set(PreimageKey::new(*hash, PreimageKeyType::Keccak256).into(), code.into()); } HintType::StartingL2Output => { const OUTPUT_ROOT_VERSION: u8 = 0; const L2_TO_L1_MESSAGE_PASSER_ADDRESS: Address = address!("4200000000000000000000000000000000000016"); - if !hint_data.is_empty() { + if hint_data.len() != 32 { anyhow::bail!("Invalid hint data length: {}", hint_data.len()); } @@ -388,8 +389,15 @@ where raw_output[96..128].copy_from_slice(self.l2_head.as_ref()); let output_root = keccak256(raw_output); + if output_root.as_slice() != hint_data.as_ref() { + anyhow::bail!("Output root does not match L2 head."); + } + let mut kv_write_lock = self.kv_store.write().await; - kv_write_lock.set(output_root, raw_output.into()); + kv_write_lock.set( + PreimageKey::new(*output_root, PreimageKeyType::Keccak256).into(), + raw_output.into(), + ); } HintType::L2StateNode => { if hint_data.len() != 32 { @@ -410,7 +418,10 @@ where .map_err(|e| anyhow!("Failed to fetch preimage: {e}"))?; let mut kv_write_lock = self.kv_store.write().await; - kv_write_lock.set(hash, preimage.into()); + kv_write_lock.set( + PreimageKey::new(*hash, PreimageKeyType::Keccak256).into(), + preimage.into(), + ); } HintType::L2AccountProof => { if hint_data.len() != 8 + 20 { @@ -422,7 +433,7 @@ where .try_into() .map_err(|e| anyhow!("Error converting hint data to u64: {e}"))?, ); - let address = Address::from_slice(&hint_data.as_ref()[8..]); + let address = Address::from_slice(&hint_data.as_ref()[8..28]); let proof_response = self .l2_provider @@ -449,7 +460,7 @@ where .try_into() .map_err(|e| anyhow!("Error converting hint data to u64: {e}"))?, ); - let address = Address::from_slice(&hint_data.as_ref()[8..]); + let address = Address::from_slice(&hint_data.as_ref()[8..28]); let slot = B256::from_slice(&hint_data.as_ref()[28..]); let mut proof_response = self diff --git a/bin/host/src/main.rs b/bin/host/src/main.rs index b6ec726ea..496879e16 100644 --- a/bin/host/src/main.rs +++ b/bin/host/src/main.rs @@ -33,7 +33,7 @@ mod server; mod types; mod util; -#[tokio::main] +#[tokio::main(flavor = "multi_thread")] async fn main() -> Result<()> { let cfg = HostCli::parse(); init_tracing_subscriber(cfg.v)?; diff --git a/bin/host/src/server.rs b/bin/host/src/server.rs index b26d07f67..5434ca8fb 100644 --- a/bin/host/src/server.rs +++ b/bin/host/src/server.rs @@ -47,8 +47,11 @@ where /// Starts the [PreimageServer] and waits for incoming requests. pub async fn start(self) -> Result<()> { // Create the futures for the oracle server and hint router. - let server_fut = - Self::start_oracle_server(self.kv_store, self.fetcher.clone(), self.oracle_server); + let server_fut = Self::start_oracle_server( + self.kv_store.clone(), + self.fetcher.clone(), + self.oracle_server, + ); let hinter_fut = Self::start_hint_router(self.hint_reader, self.fetcher); // Spawn tasks for the futures and wait for them to complete. diff --git a/bin/host/src/util.rs b/bin/host/src/util.rs index f08444e04..ed62c0f47 100644 --- a/bin/host/src/util.rs +++ b/bin/host/src/util.rs @@ -11,7 +11,6 @@ use kona_common::FileDescriptor; use kona_preimage::PipeHandle; use reqwest::Client; use std::{fs::File, os::fd::AsRawFd}; -use tempfile::tempfile; use tokio::task::JoinHandle; /// Parses a hint from a string. @@ -34,7 +33,10 @@ pub(crate) fn parse_hint(s: &str) -> Result<(HintType, Bytes)> { /// Creates two temporary files that are connected by a pipe. pub(crate) fn create_temp_files() -> Result<(File, File)> { - let (read, write) = (tempfile().map_err(|e| anyhow!(e))?, tempfile().map_err(|e| anyhow!(e))?); + let (read, write) = ( + tempfile::tempfile().map_err(|e| anyhow!(e))?, + tempfile::tempfile().map_err(|e| anyhow!(e))?, + ); Ok((read, write)) } diff --git a/bin/programs/client/Cargo.toml b/bin/programs/client/Cargo.toml index 0f93ac4c7..50b127ccd 100644 --- a/bin/programs/client/Cargo.toml +++ b/bin/programs/client/Cargo.toml @@ -20,6 +20,7 @@ revm = { workspace = true, features = ["optimism"] } lru.workspace = true spin.workspace = true async-trait.workspace = true +tracing.workspace = true # local kona-common = { path = "../../../crates/common", version = "0.0.1" } @@ -29,6 +30,9 @@ kona-primitives = { path = "../../../crates/primitives", version = "0.0.1" } kona-mpt = { path = "../../../crates/mpt", version = "0.0.1" } kona-derive = { path = "../../../crates/derive", version = "0.0.1" } +# external +tracing-subscriber = "0.3.18" + [dev-dependencies] serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.117" diff --git a/bin/programs/client/src/hint.rs b/bin/programs/client/src/hint.rs index de7efab93..279d357ff 100644 --- a/bin/programs/client/src/hint.rs +++ b/bin/programs/client/src/hint.rs @@ -37,7 +37,7 @@ pub enum HintType { impl HintType { /// Encodes the hint type as a string. pub fn encode_with(&self, data: &[&[u8]]) -> String { - let concatenated = data.iter().map(hex::encode).collect::>().join(" "); + let concatenated = data.iter().map(hex::encode).collect::>().join(""); alloc::format!("{} {}", self, concatenated) } } @@ -55,7 +55,7 @@ impl TryFrom<&str> for HintType { "l2-block-header" => Ok(HintType::L2BlockHeader), "l2-transactions" => Ok(HintType::L2Transactions), "l2-code" => Ok(HintType::L2Code), - "l2-output" => Ok(HintType::StartingL2Output), + "starting-l2-output" => Ok(HintType::StartingL2Output), "l2-state-node" => Ok(HintType::L2StateNode), "l2-account-proof" => Ok(HintType::L2AccountProof), "l2-account-storage-proof" => Ok(HintType::L2AccountStorageProof), diff --git a/bin/programs/client/src/l1/blob_provider.rs b/bin/programs/client/src/l1/blob_provider.rs index cdf3a3f4c..01b7375a2 100644 --- a/bin/programs/client/src/l1/blob_provider.rs +++ b/bin/programs/client/src/l1/blob_provider.rs @@ -14,7 +14,7 @@ use kona_preimage::{HintWriterClient, PreimageKey, PreimageKeyType, PreimageOrac use kona_primitives::BlockInfo; /// An oracle-backed blob provider. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OracleBlobProvider { oracle: Arc, } @@ -70,6 +70,8 @@ impl OracleBlobProvider { blob[(i as usize) << 5..(i as usize + 1) << 5].copy_from_slice(field_element.as_ref()); } + tracing::info!(target: "client_oracle", "Retrieved blob {blob_hash:?} from the oracle."); + Ok(blob) } } diff --git a/bin/programs/client/src/l1/chain_provider.rs b/bin/programs/client/src/l1/chain_provider.rs index 5ae188890..c6725c9e7 100644 --- a/bin/programs/client/src/l1/chain_provider.rs +++ b/bin/programs/client/src/l1/chain_provider.rs @@ -14,7 +14,7 @@ use kona_preimage::{HintWriterClient, PreimageKey, PreimageKeyType, PreimageOrac use kona_primitives::BlockInfo; /// The oracle-backed L1 chain provider for the client program. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OracleL1ChainProvider { /// The boot information boot_info: Arc, diff --git a/bin/programs/client/src/l1/driver.rs b/bin/programs/client/src/l1/driver.rs new file mode 100644 index 000000000..e38eb09ce --- /dev/null +++ b/bin/programs/client/src/l1/driver.rs @@ -0,0 +1,180 @@ +//! Contains the [DerivationDriver] struct, which handles the [L2PayloadAttributes] derivation +//! process. + +use super::{OracleBlobProvider, OracleL1ChainProvider}; +use crate::{l2::OracleL2ChainProvider, BootInfo, CachingOracle, HintType, HINT_WRITER}; +use alloc::sync::Arc; +use alloy_consensus::{Header, Sealed}; +use anyhow::{anyhow, Result}; +use core::fmt::Debug; +use kona_derive::{ + pipeline::{DerivationPipeline, Pipeline, PipelineBuilder}, + sources::EthereumDataSource, + stages::{ + AttributesQueue, BatchQueue, ChannelBank, ChannelReader, FrameQueue, L1Retrieval, + L1Traversal, StatefulAttributesBuilder, + }, + traits::{ChainProvider, L2ChainProvider}, +}; +use kona_mpt::TrieDBFetcher; +use kona_preimage::{HintWriterClient, PreimageKey, PreimageKeyType, PreimageOracleClient}; +use kona_primitives::{BlockInfo, L2AttributesWithParent, L2BlockInfo}; +use tracing::{info, warn}; + +/// An oracle-backed derivation pipeline. +pub type OraclePipeline = + DerivationPipeline, OracleL2ChainProvider>; + +/// An oracle-backed Ethereum data source. +pub type OracleDataProvider = EthereumDataSource; + +/// An oracle-backed payload attributes builder for the `AttributesQueue` stage of the derivation +/// pipeline. +pub type OracleAttributesBuilder = + StatefulAttributesBuilder; + +/// An oracle-backed attributes queue for the derivation pipeline. +pub type OracleAttributesQueue = AttributesQueue< + BatchQueue< + ChannelReader< + ChannelBank>>>, + >, + OracleL2ChainProvider, + >, + OracleAttributesBuilder, +>; + +/// The [DerivationDriver] struct is responsible for handling the [L2PayloadAttributes] derivation +/// process. +/// +/// It contains an inner [OraclePipeline] that is used to derive the attributes, backed by +/// oracle-based data sources. +#[derive(Debug)] +pub struct DerivationDriver { + /// The current L2 safe head. + l2_safe_head: L2BlockInfo, + /// The header of the L2 safe head. + l2_safe_head_header: Sealed
, + /// The inner pipeline. + pipeline: OraclePipeline, +} + +impl DerivationDriver { + /// Returns the current L2 safe head block information. + pub fn l2_safe_head(&self) -> &L2BlockInfo { + &self.l2_safe_head + } + + /// Returns the header of the current L2 safe head. + pub fn l2_safe_head_header(&self) -> &Sealed
{ + &self.l2_safe_head_header + } + + /// Creates a new [DerivationDriver] with the given configuration, blob provider, and chain + /// providers. + /// + /// ## Takes + /// - `cfg`: The rollup configuration. + /// - `blob_provider`: The blob provider. + /// - `chain_provider`: The L1 chain provider. + /// - `l2_chain_provider`: The L2 chain provider. + /// + /// ## Returns + /// - A new [DerivationDriver] instance. + pub async fn new( + boot_info: &BootInfo, + caching_oracle: &CachingOracle, + blob_provider: OracleBlobProvider, + mut chain_provider: OracleL1ChainProvider, + mut l2_chain_provider: OracleL2ChainProvider, + ) -> Result { + let cfg = Arc::new(boot_info.rollup_config.clone()); + + // Fetch the startup information. + let (l1_origin, l2_safe_head, l2_safe_head_header) = Self::find_startup_info( + caching_oracle, + boot_info, + &mut chain_provider, + &mut l2_chain_provider, + ) + .await?; + + // Construct the pipeline. + let attributes = StatefulAttributesBuilder::new( + cfg.clone(), + l2_chain_provider.clone(), + chain_provider.clone(), + ); + let dap = EthereumDataSource::new(chain_provider.clone(), blob_provider, &cfg); + let pipeline = PipelineBuilder::new() + .rollup_config(cfg) + .dap_source(dap) + .l2_chain_provider(l2_chain_provider) + .chain_provider(chain_provider) + .builder(attributes) + .origin(l1_origin) + .build(); + + Ok(Self { l2_safe_head, l2_safe_head_header, pipeline }) + } + + /// Produces the disputed [L2AttributesWithParent] payload, directly after the starting L2 + /// output root passed through the [BootInfo]. + pub async fn produce_disputed_payload(&mut self) -> Result { + // As we start the safe head at the disputed block's parent, we step the pipeline until the + // first attributes are produced. All batches at and before the safe head will be + // dropped, so the first payload will always be the disputed one. + let mut attributes = None; + while attributes.is_none() { + match self.pipeline.step(self.l2_safe_head).await { + Ok(_) => info!(target: "client_derivation_driver", "Stepped derivation pipeline"), + Err(e) => { + warn!(target: "client_derivation_driver", "Failed to step derivation pipeline: {:?}", e) + } + } + + attributes = self.pipeline.next_attributes(); + } + + Ok(attributes.expect("Must be some")) + } + + /// Finds the startup information for the derivation pipeline. + /// + /// ## Takes + /// - `caching_oracle`: The caching oracle. + /// - `boot_info`: The boot information. + /// - `chain_provider`: The L1 chain provider. + /// - `l2_chain_provider`: The L2 chain provider. + /// + /// ## Returns + /// - A tuple containing the L1 origin block information and the L2 safe head information. + async fn find_startup_info( + caching_oracle: &CachingOracle, + boot_info: &BootInfo, + chain_provider: &mut OracleL1ChainProvider, + l2_chain_provider: &mut OracleL2ChainProvider, + ) -> Result<(BlockInfo, L2BlockInfo, Sealed
)> { + // Find the initial safe head, based off of the starting L2 block number in the boot info. + HINT_WRITER + .write(&HintType::StartingL2Output.encode_with(&[boot_info.l2_output_root.as_ref()])) + .await?; + let mut output_preimage = [0u8; 128]; + caching_oracle + .get_exact( + PreimageKey::new(*boot_info.l2_output_root, PreimageKeyType::Keccak256), + &mut output_preimage, + ) + .await?; + + let safe_hash = + output_preimage[96..128].try_into().map_err(|_| anyhow!("Invalid L2 output root"))?; + let safe_header = l2_chain_provider.header_by_hash(safe_hash)?; + let safe_head_info = l2_chain_provider.l2_block_info_by_number(safe_header.number).await?; + + let l1_origin = + chain_provider.block_info_by_number(safe_head_info.l1_origin.number).await?; + + Ok((l1_origin, safe_head_info, Sealed::new_unchecked(safe_header, safe_hash))) + } +} diff --git a/bin/programs/client/src/l1/mod.rs b/bin/programs/client/src/l1/mod.rs index 355850feb..9b0a04744 100644 --- a/bin/programs/client/src/l1/mod.rs +++ b/bin/programs/client/src/l1/mod.rs @@ -1,5 +1,11 @@ //! Contains the L1 constructs of the client program. +mod driver; +pub use driver::{ + DerivationDriver, OracleAttributesBuilder, OracleAttributesQueue, OracleDataProvider, + OraclePipeline, +}; + mod blob_provider; pub use blob_provider::OracleBlobProvider; diff --git a/bin/programs/client/src/l2/chain_provider.rs b/bin/programs/client/src/l2/chain_provider.rs index 56fc82147..92cf3e90d 100644 --- a/bin/programs/client/src/l2/chain_provider.rs +++ b/bin/programs/client/src/l2/chain_provider.rs @@ -1,7 +1,7 @@ //! Contains the concrete implementation of the [L2ChainProvider] trait for the client program. use crate::{BootInfo, CachingOracle, HintType, HINT_WRITER}; -use alloc::{boxed::Box, string::ToString, sync::Arc, vec::Vec}; +use alloc::{boxed::Box, sync::Arc, vec::Vec}; use alloy_consensus::Header; use alloy_eips::eip2718::Decodable2718; use alloy_primitives::{Bytes, B256}; @@ -17,22 +17,31 @@ use kona_primitives::{ use op_alloy_consensus::OpTxEnvelope; /// The oracle-backed L2 chain provider for the client program. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OracleL2ChainProvider { - /// The rollup configuration. - cfg: Arc, /// The boot information boot_info: Arc, /// The preimage oracle client. oracle: Arc, } +impl OracleL2ChainProvider { + /// Creates a new [OracleL2ChainProvider] with the given boot information and oracle client. + pub fn new(boot_info: Arc, oracle: Arc) -> Self { + Self { boot_info, oracle } + } +} + impl OracleL2ChainProvider { /// Returns a [Header] corresponding to the given L2 block number, by walking back from the /// L2 safe head. async fn header_by_number(&mut self, block_number: u64) -> Result
{ // Fetch the starting L2 output preimage. - HINT_WRITER.write(&HintType::StartingL2Output.to_string()).await?; + HINT_WRITER + .write( + &HintType::StartingL2Output.encode_with(&[self.boot_info.l2_output_root.as_ref()]), + ) + .await?; let output_preimage = self .oracle .get(PreimageKey::new(*self.boot_info.l2_output_root, PreimageKeyType::Keccak256)) @@ -65,7 +74,7 @@ impl L2ChainProvider for OracleL2ChainProvider { let payload = self.payload_by_number(number).await?; // Construct the system config from the payload. - payload.to_l2_block_ref(self.cfg.as_ref()) + payload.to_l2_block_ref(&self.boot_info.rollup_config) } async fn payload_by_number(&mut self, number: u64) -> Result { @@ -90,7 +99,7 @@ impl L2ChainProvider for OracleL2ChainProvider { let optimism_block = OpBlock { header, body: transactions, - withdrawals: self.cfg.is_canyon_active(timestamp).then(Vec::new), + withdrawals: self.boot_info.rollup_config.is_canyon_active(timestamp).then(Vec::new), ..Default::default() }; Ok(optimism_block.into()) diff --git a/bin/programs/client/src/main.rs b/bin/programs/client/src/main.rs index aa0e03170..bcd4c565b 100644 --- a/bin/programs/client/src/main.rs +++ b/bin/programs/client/src/main.rs @@ -5,21 +5,96 @@ #![no_std] #![cfg_attr(any(target_arch = "mips", target_arch = "riscv64"), no_main)] -use kona_client::{BootInfo, CachingOracle}; +use alloc::sync::Arc; +use alloy_consensus::Header; +use anyhow::anyhow; +use kona_client::{ + l1::{DerivationDriver, OracleBlobProvider, OracleL1ChainProvider}, + l2::{OracleL2ChainProvider, StatelessL2BlockExecutor, TrieDBHintWriter}, + BootInfo, CachingOracle, +}; use kona_common::io; use kona_common_proc::client_entry; +use tracing::Level; extern crate alloc; /// The size of the LRU cache in the oracle. -const ORACLE_LRU_SIZE: usize = 16; +const ORACLE_LRU_SIZE: usize = 1024; #[client_entry(0x77359400)] fn main() -> Result<()> { + init_tracing_subscriber(2)?; + kona_common::block_on(async move { - let caching_oracle = CachingOracle::new(ORACLE_LRU_SIZE); - let boot = BootInfo::load(&caching_oracle).await?; - io::print(&alloc::format!("{:?}\n", boot)); + //////////////////////////////////////////////////////////////// + // PROLOGUE // + //////////////////////////////////////////////////////////////// + + let oracle = Arc::new(CachingOracle::new(ORACLE_LRU_SIZE)); + let boot = Arc::new(BootInfo::load(oracle.as_ref()).await?); + let l1_provider = OracleL1ChainProvider::new(boot.clone(), oracle.clone()); + let l2_provider = OracleL2ChainProvider::new(boot.clone(), oracle.clone()); + let beacon = OracleBlobProvider::new(oracle.clone()); + + //////////////////////////////////////////////////////////////// + // DERIVATION & EXECUTION // + //////////////////////////////////////////////////////////////// + + let mut driver = DerivationDriver::new( + boot.as_ref(), + oracle.as_ref(), + beacon, + l1_provider, + l2_provider.clone(), + ) + .await?; + let attributes = driver.produce_disputed_payload().await?; + + let cfg = Arc::new(boot.rollup_config.clone()); + let mut executor = StatelessL2BlockExecutor::new( + cfg, + driver.l2_safe_head_header().clone(), + l2_provider, + TrieDBHintWriter, + ); + let Header { number, .. } = *executor.execute_payload(attributes.attributes)?; + let output_root = executor.compute_output_root()?; + + //////////////////////////////////////////////////////////////// + // EPILOGUE // + //////////////////////////////////////////////////////////////// + + assert_eq!(number, boot.l2_claim_block); + assert_eq!(output_root, boot.l2_claim); + + tracing::info!( + target: "client", + "Successfully validated L2 block #{number} with output root {output_root}", + number = number, + output_root = output_root + ); + Ok::<_, anyhow::Error>(()) }) } + +/// Initializes the tracing subscriber +/// +/// # Arguments +/// * `verbosity_level` - The verbosity level (0-4) +/// +/// # Returns +/// * `Result<()>` - Ok if successful, Err otherwise. +pub fn init_tracing_subscriber(verbosity_level: u8) -> anyhow::Result<()> { + let subscriber = tracing_subscriber::fmt() + .with_max_level(match verbosity_level { + 0 => Level::ERROR, + 1 => Level::WARN, + 2 => Level::INFO, + 3 => Level::DEBUG, + _ => Level::TRACE, + }) + .finish(); + tracing::subscriber::set_global_default(subscriber).map_err(|e| anyhow!(e)) +} diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index e8f9ebb98..f2a849444 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -15,3 +15,4 @@ cfg-if.workspace = true # external linked_list_allocator = "0.10.5" +lazy_static = "1.4.0" diff --git a/crates/common/src/io.rs b/crates/common/src/io.rs index 9337fb54b..65b0aacf6 100644 --- a/crates/common/src/io.rs +++ b/crates/common/src/io.rs @@ -59,14 +59,21 @@ mod native_io { extern crate std; use crate::{io::FileDescriptor, traits::BasicKernelInterface}; - use alloc::boxed::Box; use anyhow::{anyhow, Result}; + use lazy_static::lazy_static; use std::{ + collections::HashMap, fs::File, io::{Read, Seek, SeekFrom, Write}, os::fd::FromRawFd, + sync::Mutex, }; + lazy_static! { + static ref READ_CURSOR: Mutex> = Mutex::new(HashMap::new()); + static ref WRITE_CURSOR: Mutex> = Mutex::new(HashMap::new()); + } + /// Mock IO implementation for native tests. #[derive(Debug)] pub struct NativeIO; @@ -74,37 +81,43 @@ mod native_io { impl BasicKernelInterface for NativeIO { fn write(fd: FileDescriptor, buf: &[u8]) -> Result { let raw_fd: usize = fd.into(); - let file = unsafe { - let b = Box::new(File::from_raw_fd(raw_fd as i32)); - Box::leak(b) - }; - let n = file - .write(buf) - .map_err(|e| anyhow!("Error writing to buffer to file descriptor: {e}"))?; + let mut file = unsafe { File::from_raw_fd(raw_fd as i32) }; + + let mut cursor_entry_lock = + WRITE_CURSOR.lock().map_err(|e| anyhow!("Failed to acquire write lock: {e}"))?; + let cursor_entry = cursor_entry_lock.entry(raw_fd).or_insert(0); // Reset the cursor back to before the data we just wrote for the reader's consumption. - file.seek(SeekFrom::Current(-(n as i64))) + file.seek(SeekFrom::Start(*cursor_entry as u64)) .map_err(|e| anyhow!("Failed to reset file cursor to 0: {e}"))?; - // Attempt to sync the file to disk. This is a best-effort operation and should not - // throw if it fails. - let _ = file.sync_all(); + file.write_all(buf) + .map_err(|e| anyhow!("Error writing to buffer to file descriptor: {e}"))?; - Ok(n) + *cursor_entry += buf.len(); + + std::mem::forget(file); + + Ok(buf.len()) } fn read(fd: FileDescriptor, buf: &mut [u8]) -> Result { let raw_fd: usize = fd.into(); - let file = unsafe { - let b = Box::new(File::from_raw_fd(raw_fd as i32)); - Box::leak(b) - }; + let mut file = unsafe { File::from_raw_fd(raw_fd as i32) }; + + let mut cursor_entry_lock = + READ_CURSOR.lock().map_err(|e| anyhow!("Failed to acquire read lock: {e}"))?; + let cursor_entry = cursor_entry_lock.entry(raw_fd).or_insert(0); + + file.seek(SeekFrom::Start(*cursor_entry as u64)) + .map_err(|e| anyhow!("Failed to reset file cursor to 0: {e}"))?; + let n = file.read(buf).map_err(|e| anyhow!("Error reading from file descriptor: {e}"))?; - // Attempt to sync the file to disk. This is a best-effort operation and should not - // throw if it fails. - let _ = file.sync_all(); + *cursor_entry += n; + + std::mem::forget(file); Ok(n) } diff --git a/crates/preimage/src/oracle.rs b/crates/preimage/src/oracle.rs index 9e6833a7e..dd5b0c76f 100644 --- a/crates/preimage/src/oracle.rs +++ b/crates/preimage/src/oracle.rs @@ -39,6 +39,11 @@ impl PreimageOracleClient for OracleReader { debug!(target: "oracle_client", "Requesting data from preimage oracle. Key {key}"); let length = self.write_key(key).await?; + + if length == 0 { + return Ok(Default::default()); + } + let mut data_buffer = alloc::vec![0; length]; debug!(target: "oracle_client", "Reading data from preimage oracle. Key {key}"); @@ -66,6 +71,10 @@ impl PreimageOracleClient for OracleReader { bail!("Buffer size {} does not match preimage size {}", buf.len(), length); } + if length == 0 { + return Ok(()); + } + self.pipe_handle.read_exact(buf).await?; debug!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}"); diff --git a/crates/preimage/src/pipe.rs b/crates/preimage/src/pipe.rs index 9d8961808..761bc5bd3 100644 --- a/crates/preimage/src/pipe.rs +++ b/crates/preimage/src/pipe.rs @@ -1,7 +1,7 @@ //! This module contains a rudamentary pipe between two file descriptors, using [kona_common::io] //! for reading and writing from the file descriptors. -use anyhow::{anyhow, Result}; +use anyhow::Result; use core::{ cell::RefCell, cmp::Ordering, @@ -76,8 +76,7 @@ impl Future for ReadFuture<'_> { self.read += chunk_read; match self.read.cmp(&buf_len) { - Ordering::Equal => Poll::Ready(Ok(self.read)), - Ordering::Greater => Poll::Ready(Err(anyhow!("Read more bytes than buffer size"))), + Ordering::Greater | Ordering::Equal => Poll::Ready(Ok(self.read)), Ordering::Less => { // Register the current task to be woken up when it can make progress ctx.waker().wake_by_ref(); @@ -105,6 +104,11 @@ impl Future for WriteFuture<'_> { Ok(0) => Poll::Ready(Ok(self.written)), // Finished writing Ok(n) => { self.written += n; + + if self.written >= self.buf.len() { + return Poll::Ready(Ok(self.written)); + } + // Register the current task to be woken up when it can make progress ctx.waker().wake_by_ref(); Poll::Pending