diff --git a/Cargo.lock b/Cargo.lock index 4b76bcdd..ca56092a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2528,7 +2528,7 @@ dependencies = [ "enr 0.12.1", "fnv", "futures", - "hashlink 0.8.4", + "hashlink", "hex", "hkdf", "lazy_static", @@ -3263,12 +3263,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" -[[package]] -name = "fallible-streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" - [[package]] name = "fastrand" version = "2.1.1" @@ -3842,15 +3836,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "hashlink" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" -dependencies = [ - "hashbrown 0.14.5", -] - [[package]] name = "hdrhistogram" version = "7.5.4" @@ -4561,17 +4546,6 @@ dependencies = [ "libc", ] -[[package]] -name = "libsqlite3-sys" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] - [[package]] name = "libz-sys" version = "1.1.20" @@ -6265,7 +6239,6 @@ dependencies = [ "rand 0.9.0-alpha.2", "rand_chacha 0.9.0-alpha.2", "redis", - "rusqlite", "serde", "serde_json", "tempfile", @@ -7850,21 +7823,6 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" -[[package]] -name = "rusqlite" -version = "0.31.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" -dependencies = [ - "bitflags 2.6.0", - "chrono", - "fallible-iterator", - "fallible-streaming-iterator", - "hashlink 0.9.1", - "libsqlite3-sys", - "smallvec", -] - [[package]] name = "rust-embed" version = "8.5.0" diff --git a/Cargo.toml b/Cargo.toml index ebac84ca..643964b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -161,9 +161,6 @@ anyhow = "1.0" thiserror = "1.0" thiserror-no-std = "2.0.2" -# SQLite -rusqlite = { version = "0.31.0", features = ["bundled"] } - # redis redis = { version = "=0.27.3" } diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index 3e34cb55..e4848f6c 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -1,17 +1,17 @@ use crate::{merge, prover::NativeProver}; use alloy_primitives::{Address, B256}; -use clap::{Args, ValueEnum}; +use clap::Args; use raiko_lib::{ - consts::VerifierType, input::{ AggregationGuestInput, AggregationGuestOutput, BlobProofType, GuestInput, GuestOutput, }, + proof_type::ProofType, prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_with::{serde_as, DisplayFromStr}; -use std::{collections::HashMap, fmt::Display, path::Path, str::FromStr}; +use std::{collections::HashMap, fmt::Display, path::Path}; use utoipa::ToSchema; #[derive(Debug, thiserror::Error, ToSchema)] @@ -79,209 +79,121 @@ impl From for RaikoError { pub type RaikoResult = Result; -#[derive( - PartialEq, - Eq, - PartialOrd, - Ord, - Clone, - Debug, - Default, - Deserialize, - Serialize, - ToSchema, - Hash, - ValueEnum, - Copy, -)] -/// Available proof types. -pub enum ProofType { - #[default] - /// # Native - /// - /// This builds the block the same way the node does and then runs the result. - Native, - /// # Sp1 - /// - /// Uses the SP1 prover to build the block. - Sp1, - /// # Sgx - /// - /// Builds the block on a SGX supported CPU to create a proof. - Sgx, - /// # Risc0 - /// - /// Uses the RISC0 prover to build the block. - Risc0, -} - -impl std::fmt::Display for ProofType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - ProofType::Native => "native", - ProofType::Sp1 => "sp1", - ProofType::Sgx => "sgx", - ProofType::Risc0 => "risc0", - }) - } -} - -impl FromStr for ProofType { - type Err = RaikoError; - - fn from_str(s: &str) -> Result { - match s.trim().to_lowercase().as_str() { - "native" => Ok(ProofType::Native), - "sp1" => Ok(ProofType::Sp1), - "sgx" => Ok(ProofType::Sgx), - "risc0" => Ok(ProofType::Risc0), - _ => Err(RaikoError::InvalidProofType(s.to_string())), +/// Run the prover driver depending on the proof type. +pub async fn run_prover( + proof_type: ProofType, + input: GuestInput, + output: &GuestOutput, + config: &Value, + store: Option<&mut dyn IdWrite>, +) -> RaikoResult { + match proof_type { + ProofType::Native => NativeProver::run(input.clone(), output, config, store) + .await + .map_err(>::into), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return sp1_driver::Sp1Prover::run(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) } - } -} - -impl TryFrom for ProofType { - type Error = RaikoError; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(Self::Native), - 1 => Ok(Self::Sp1), - 2 => Ok(Self::Sgx), - 3 => Ok(Self::Risc0), - _ => Err(RaikoError::Conversion("Invalid u8".to_owned())), + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return risc0_driver::Risc0Prover::run(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return sgx_prover::SgxProver::run(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) } } } -impl From for VerifierType { - fn from(val: ProofType) -> Self { - match val { - ProofType::Native => VerifierType::None, - ProofType::Sp1 => VerifierType::SP1, - ProofType::Sgx => VerifierType::SGX, - ProofType::Risc0 => VerifierType::RISC0, +/// Run the prover driver depending on the proof type. +pub async fn aggregate_proofs( + proof_type: ProofType, + input: AggregationGuestInput, + output: &AggregationGuestOutput, + config: &Value, + store: Option<&mut dyn IdWrite>, +) -> RaikoResult { + let proof = match proof_type { + ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store) + .await + .map_err(>::into), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) } - } + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store) + .await + .map_err(|e| e.into()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + }?; + + Ok(proof) } -impl ProofType { - /// Run the prover driver depending on the proof type. - pub async fn run_prover( - &self, - input: GuestInput, - output: &GuestOutput, - config: &Value, - store: Option<&mut dyn IdWrite>, - ) -> RaikoResult { - match self { - ProofType::Native => NativeProver::run(input.clone(), output, config, store) +pub async fn cancel_proof( + proof_type: ProofType, + proof_key: ProofKey, + read: Box<&mut dyn IdStore>, +) -> RaikoResult<()> { + match proof_type { + ProofType::Native => NativeProver::cancel(proof_key, read) + .await + .map_err(>::into), + ProofType::Sp1 => { + #[cfg(feature = "sp1")] + return sp1_driver::Sp1Prover::cancel(proof_key, read) .await - .map_err(>::into), - ProofType::Sp1 => { - #[cfg(feature = "sp1")] - return sp1_driver::Sp1Prover::run(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sp1"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Risc0 => { - #[cfg(feature = "risc0")] - return risc0_driver::Risc0Prover::run(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "risc0"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Sgx => { - #[cfg(feature = "sgx")] - return sgx_prover::SgxProver::run(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sgx"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } + .map_err(|e| e.into()); + #[cfg(not(feature = "sp1"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) } - } - - /// Run the prover driver depending on the proof type. - pub async fn aggregate_proofs( - &self, - input: AggregationGuestInput, - output: &AggregationGuestOutput, - config: &Value, - store: Option<&mut dyn IdWrite>, - ) -> RaikoResult { - let proof = match self { - ProofType::Native => NativeProver::aggregate(input.clone(), output, config, store) + ProofType::Risc0 => { + #[cfg(feature = "risc0")] + return risc0_driver::Risc0Prover::cancel(proof_key, read) .await - .map_err(>::into), - ProofType::Sp1 => { - #[cfg(feature = "sp1")] - return sp1_driver::Sp1Prover::aggregate(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sp1"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Risc0 => { - #[cfg(feature = "risc0")] - return risc0_driver::Risc0Prover::aggregate(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "risc0"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Sgx => { - #[cfg(feature = "sgx")] - return sgx_prover::SgxProver::aggregate(input.clone(), output, config, store) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sgx"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - }?; - - Ok(proof) - } - - pub async fn cancel_proof( - &self, - proof_key: ProofKey, - read: Box<&mut dyn IdStore>, - ) -> RaikoResult<()> { - match self { - ProofType::Native => NativeProver::cancel(proof_key, read) + .map_err(|e| e.into()); + #[cfg(not(feature = "risc0"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + ProofType::Sgx => { + #[cfg(feature = "sgx")] + return sgx_prover::SgxProver::cancel(proof_key, read) .await - .map_err(>::into), - ProofType::Sp1 => { - #[cfg(feature = "sp1")] - return sp1_driver::Sp1Prover::cancel(proof_key, read) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sp1"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Risc0 => { - #[cfg(feature = "risc0")] - return risc0_driver::Risc0Prover::cancel(proof_key, read) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "risc0"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - ProofType::Sgx => { - #[cfg(feature = "sgx")] - return sgx_prover::SgxProver::cancel(proof_key, read) - .await - .map_err(|e| e.into()); - #[cfg(not(feature = "sgx"))] - Err(RaikoError::FeatureNotSupportedError(*self)) - } - }?; - Ok(()) - } + .map_err(|e| e.into()); + #[cfg(not(feature = "sgx"))] + Err(RaikoError::FeatureNotSupportedError(proof_type)) + } + }?; + Ok(()) } #[serde_as] diff --git a/core/src/lib.rs b/core/src/lib.rs index d424568a..46113094 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, hint::black_box}; use alloy_primitives::Address; use alloy_rpc_types::EIP1186AccountProofResponse; +use interfaces::{cancel_proof, run_prover}; use raiko_lib::{ builder::{create_mem_db, RethBlockBuilder}, consts::ChainSpec, @@ -110,10 +111,7 @@ impl Raiko { store: Option<&mut dyn IdWrite>, ) -> RaikoResult { let config = serde_json::to_value(&self.request)?; - self.request - .proof_type - .run_prover(input, output, &config, store) - .await + run_prover(self.request.proof_type, input, output, &config, store).await } pub async fn cancel( @@ -121,7 +119,7 @@ impl Raiko { proof_key: ProofKey, read: Box<&mut dyn IdStore>, ) -> RaikoResult<()> { - self.request.proof_type.cancel_proof(proof_key, read).await + cancel_proof(self.request.proof_type, proof_key, read).await } } @@ -216,26 +214,23 @@ pub fn merge(a: &mut Value, b: &Value) { #[cfg(test)] mod tests { - use crate::{ - interfaces::{ProofRequest, ProofType}, - provider::rpc::RpcBlockDataProvider, - ChainSpec, Raiko, - }; + use crate::interfaces::aggregate_proofs; + use crate::{interfaces::ProofRequest, provider::rpc::RpcBlockDataProvider, ChainSpec, Raiko}; use alloy_primitives::Address; use alloy_provider::Provider; - use clap::ValueEnum; use raiko_lib::{ consts::{Network, SupportedChainSpecs}, input::{AggregationGuestInput, AggregationGuestOutput, BlobProofType}, primitives::B256, + proof_type::ProofType, prover::Proof, }; use serde_json::{json, Value}; - use std::{collections::HashMap, env}; + use std::{collections::HashMap, env, str::FromStr}; fn get_proof_type_from_env() -> ProofType { let proof_type = env::var("TARGET").unwrap_or("native".to_string()); - ProofType::from_str(&proof_type, true).unwrap() + ProofType::from_str(&proof_type).unwrap() } fn is_ci() -> bool { @@ -474,15 +469,15 @@ mod tests { let output = AggregationGuestOutput { hash: B256::ZERO }; - let aggregated_proof = proof_type - .aggregate_proofs( - input, - &output, - &serde_json::to_value(&test_proof_params(false)).unwrap(), - None, - ) - .await - .expect("proof aggregation failed"); + let aggregated_proof = aggregate_proofs( + proof_type, + input, + &output, + &serde_json::to_value(&test_proof_params(false)).unwrap(), + None, + ) + .await + .expect("proof aggregation failed"); println!("aggregated proof: {aggregated_proof:?}"); } } diff --git a/host/src/cache.rs b/host/src/cache.rs index 52fe34a5..606a7f4a 100644 --- a/host/src/cache.rs +++ b/host/src/cache.rs @@ -79,15 +79,12 @@ mod test { use alloy_primitives::{Address, B256}; use alloy_provider::Provider; - use raiko_core::{ - interfaces::{ProofRequest, ProofType}, - provider::rpc::RpcBlockDataProvider, - Raiko, - }; + use raiko_core::{interfaces::ProofRequest, provider::rpc::RpcBlockDataProvider, Raiko}; use raiko_lib::input::BlobProofType; use raiko_lib::{ consts::{ChainSpec, Network, SupportedChainSpecs}, input::GuestInput, + proof_type::ProofType, }; async fn create_cache_input( diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs index 4800bb71..9256009e 100644 --- a/host/src/interfaces.rs +++ b/host/src/interfaces.rs @@ -1,5 +1,5 @@ use axum::response::IntoResponse; -use raiko_core::interfaces::ProofType; +use raiko_lib::proof_type::ProofType; use raiko_lib::prover::ProverError; use raiko_tasks::{TaskManagerError, TaskStatus}; use tokio::sync::mpsc::error::TrySendError; diff --git a/host/src/lib.rs b/host/src/lib.rs index b50c955d..6dbfd253 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -75,10 +75,6 @@ pub struct Opts { /// Set jwt secret for auth pub jwt_secret: Option, - #[arg(long, require_equals = true, default_value = "raiko.sqlite")] - /// Set the path to the sqlite db file - pub sqlite_file: PathBuf, - #[arg(long, require_equals = true, default_value = "1048576")] pub max_db_size: usize, @@ -132,7 +128,6 @@ impl Opts { impl From for TaskManagerOpts { fn from(val: Opts) -> Self { Self { - sqlite_file: val.sqlite_file, max_db_size: val.max_db_size, redis_url: val.redis_url.to_string(), redis_ttl: val.redis_ttl, @@ -143,7 +138,6 @@ impl From for TaskManagerOpts { impl From<&Opts> for TaskManagerOpts { fn from(val: &Opts) -> Self { Self { - sqlite_file: val.sqlite_file.clone(), max_db_size: val.max_db_size, redis_url: val.redis_url.to_string(), redis_ttl: val.redis_ttl, diff --git a/host/src/metrics.rs b/host/src/metrics.rs index 307a1992..26babdd6 100644 --- a/host/src/metrics.rs +++ b/host/src/metrics.rs @@ -5,7 +5,7 @@ use prometheus::{ labels, register_histogram_vec, register_int_counter_vec, register_int_gauge, HistogramVec, IntCounterVec, IntGauge, }; -use raiko_core::interfaces::ProofType; +use raiko_lib::proof_type::ProofType; lazy_static! { pub static ref HOST_REQ_COUNT: IntCounterVec = register_int_counter_vec!( diff --git a/host/src/proof.rs b/host/src/proof.rs index 70da9574..1223af0c 100644 --- a/host/src/proof.rs +++ b/host/src/proof.rs @@ -4,15 +4,17 @@ use std::{ sync::Arc, }; -use anyhow::anyhow; use raiko_core::{ - interfaces::{AggregationOnlyRequest, ProofRequest, ProofType, RaikoError}, + interfaces::{ + aggregate_proofs, cancel_proof, AggregationOnlyRequest, ProofRequest, RaikoError, + }, provider::{get_task_data, rpc::RpcBlockDataProvider}, Raiko, }; use raiko_lib::{ consts::SupportedChainSpecs, input::{AggregationGuestInput, AggregationGuestOutput}, + proof_type::ProofType, prover::{IdWrite, Proof}, Measurement, }; @@ -91,25 +93,25 @@ impl ProofActor { }; let mut manager = get_task_manager(&self.opts.clone().into()); - key.proof_system - .cancel_proof( - ( - key.chain_id, - key.block_id, - key.blockhash, - key.proof_system as u8, - ), - Box::new(&mut manager), - ) - .await - .or_else(|e| { - if e.to_string().contains("No data for query") { - warn!("Task already cancelled or not yet started!"); - Ok(()) - } else { - Err::<(), HostError>(e.into()) - } - })?; + cancel_proof( + key.proof_system, + ( + key.chain_id, + key.block_id, + key.blockhash, + key.proof_system as u8, + ), + Box::new(&mut manager), + ) + .await + .or_else(|e| { + if e.to_string().contains("No data for query") { + warn!("Task already cancelled or not yet started!"); + Ok(()) + } else { + Err::<(), HostError>(e.into()) + } + })?; task.cancel(); Ok(()) } @@ -341,6 +343,9 @@ impl ProofActor { } pub async fn handle_aggregate(request: AggregationOnlyRequest, opts: &Opts) -> HostResult<()> { + let proof_type_str = request.proof_type.to_owned().unwrap_or_default(); + let proof_type = ProofType::from_str(&proof_type_str).map_err(HostError::Conversion)?; + let mut manager = get_task_manager(&opts.clone().into()); let status = manager @@ -356,12 +361,7 @@ impl ProofActor { manager .update_aggregation_task_progress(&request, TaskStatus::WorkInProgress, None) .await?; - let proof_type = ProofType::from_str( - request - .proof_type - .as_ref() - .ok_or_else(|| anyhow!("No proof type"))?, - )?; + let input = AggregationGuestInput { proofs: request.clone().proofs, }; @@ -369,16 +369,14 @@ impl ProofActor { let config = serde_json::to_value(request.clone().prover_args)?; let mut manager = get_task_manager(&opts.clone().into()); - let (status, proof) = match proof_type - .aggregate_proofs(input, &output, &config, Some(&mut manager)) - .await - { - Err(error) => { - error!("{error}"); - (HostError::from(error).into(), None) - } - Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)), - }; + let (status, proof) = + match aggregate_proofs(proof_type, input, &output, &config, Some(&mut manager)).await { + Err(error) => { + error!("{error}"); + (HostError::from(error).into(), None) + } + Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)), + }; manager .update_aggregation_task_progress(&request, status, proof.as_deref()) diff --git a/host/src/server/api/v3/proof/aggregate.rs b/host/src/server/api/v3/proof/aggregate.rs index a346c204..20973dbe 100644 --- a/host/src/server/api/v3/proof/aggregate.rs +++ b/host/src/server/api/v3/proof/aggregate.rs @@ -1,7 +1,8 @@ use std::str::FromStr; use axum::{debug_handler, extract::State, routing::post, Json, Router}; -use raiko_core::interfaces::{AggregationOnlyRequest, ProofType}; +use raiko_core::interfaces::AggregationOnlyRequest; +use raiko_lib::proof_type::ProofType; use raiko_tasks::{TaskManager, TaskStatus}; use utoipa::OpenApi; @@ -9,6 +10,7 @@ use crate::{ interfaces::HostResult, metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, server::api::v3::Status, + server::HostError, Message, ProverState, }; @@ -42,7 +44,8 @@ async fn aggregation_handler( .proof_type .as_deref() .unwrap_or_default(), - )?; + ) + .map_err(HostError::Conversion)?; inc_host_req_count(0); inc_guest_req_count(&proof_type, 0); diff --git a/host/tests/common/mod.rs b/host/tests/common/mod.rs index da9f95be..e903466e 100644 --- a/host/tests/common/mod.rs +++ b/host/tests/common/mod.rs @@ -1,8 +1,9 @@ use std::str::FromStr; -use raiko_core::interfaces::{ProofRequestOpt, ProofType, ProverSpecificOpts}; +use raiko_core::interfaces::{ProofRequestOpt, ProverSpecificOpts}; use raiko_host::{server::serve, ProverState}; use raiko_lib::consts::{Network, SupportedChainSpecs}; +use raiko_lib::proof_type::ProofType; use serde::Deserialize; use tokio_util::sync::CancellationToken; diff --git a/lib/src/consts.rs b/lib/src/consts.rs index 08297a4e..b73ed43f 100644 --- a/lib/src/consts.rs +++ b/lib/src/consts.rs @@ -17,6 +17,8 @@ use once_cell::sync::Lazy; use std::path::PathBuf; use std::{collections::HashMap, env::var}; +use crate::proof_type::ProofType; + /// U256 representation of 0. pub const ZERO: U256 = U256::ZERO; /// U256 representation of 1. @@ -136,6 +138,17 @@ pub enum VerifierType { RISC0, } +impl From for VerifierType { + fn from(val: ProofType) -> Self { + match val { + ProofType::Native => VerifierType::None, + ProofType::Sgx => VerifierType::SGX, + ProofType::Sp1 => VerifierType::SP1, + ProofType::Risc0 => VerifierType::RISC0, + } + } +} + /// Specification of a specific chain. #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] pub struct ChainSpec { diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 5042d9ae..873f5849 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -23,6 +23,7 @@ pub mod consts; pub mod input; pub mod mem_db; pub mod primitives; +pub mod proof_type; pub mod protocol_instance; pub mod prover; pub mod utils; diff --git a/lib/src/proof_type.rs b/lib/src/proof_type.rs new file mode 100644 index 00000000..c8a261fb --- /dev/null +++ b/lib/src/proof_type.rs @@ -0,0 +1,64 @@ +use serde::{Deserialize, Serialize}; + +#[derive( + PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default, Deserialize, Serialize, Hash, Copy, +)] +/// Available proof types. +pub enum ProofType { + #[default] + /// # Native + /// + /// This builds the block the same way the node does and then runs the result. + Native, + /// # Sp1 + /// + /// Uses the SP1 prover to build the block. + Sp1, + /// # Sgx + /// + /// Builds the block on a SGX supported CPU to create a proof. + Sgx, + /// # Risc0 + /// + /// Uses the RISC0 prover to build the block. + Risc0, +} + +impl std::fmt::Display for ProofType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + ProofType::Native => "native", + ProofType::Sp1 => "sp1", + ProofType::Sgx => "sgx", + ProofType::Risc0 => "risc0", + }) + } +} + +impl std::str::FromStr for ProofType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim().to_lowercase().as_str() { + "native" => Ok(ProofType::Native), + "sp1" => Ok(ProofType::Sp1), + "sgx" => Ok(ProofType::Sgx), + "risc0" => Ok(ProofType::Risc0), + _ => Err(format!("Unknown proof type {}", s)), + } + } +} + +impl TryFrom for ProofType { + type Error = String; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Native), + 1 => Ok(Self::Sp1), + 2 => Ok(Self::Sgx), + 3 => Ok(Self::Risc0), + _ => Err(format!("Unknown proof type {}", value)), + } + } +} diff --git a/taskdb/Cargo.toml b/taskdb/Cargo.toml index 1ccdb19f..13c99dea 100644 --- a/taskdb/Cargo.toml +++ b/taskdb/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" [dependencies] raiko-lib = { workspace = true } raiko-core = { workspace = true } -rusqlite = { workspace = true, features = ["chrono"], optional = true } num_enum = { workspace = true } chrono = { workspace = true, features = ["serde"] } thiserror = { workspace = true } @@ -26,11 +25,9 @@ rand = "0.9.0-alpha.1" # This is an a rand_chacha = "0.9.0-alpha.1" tempfile = "3.10.1" alloy-primitives = { workspace = true, features = ["getrandom"] } -rusqlite = { workspace = true, features = ["trace"] } [features] default = [] -sqlite = ["rusqlite"] in-memory = [] redis-db = ["redis"] diff --git a/taskdb/src/adv_sqlite.rs b/taskdb/src/adv_sqlite.rs deleted file mode 100644 index a33fc172..00000000 --- a/taskdb/src/adv_sqlite.rs +++ /dev/null @@ -1,1029 +0,0 @@ -// Raiko -// Copyright (c) 2024 Taiko Labs -// Licensed and distributed under either of -// * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). -// * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). -// at your option. This file may not be copied, modified, or distributed except according to those terms. - -//! # Raiko Task Manager -//! -//! At the moment (Apr '24) proving requires a significant amount of time -//! and maintaining a connection with a potentially external party. -//! -//! By design Raiko is stateless, it prepares inputs and forward to the various proof systems. -//! However some proving backend like Risc0's Bonsai are also stateless, -//! and only accepts proofs and return result. -//! Hence to handle crashes, networking losses and restarts, we need to persist -//! the status of proof requests, task submitted, proof received, proof forwarded. -//! -//! In the diagram: -//! _____________ ______________ _______________ -//! Taiko L2 -> | Taiko-geth | ======> | Raiko-host | =========> | Raiko-guests | -//! | Taiko-reth | | | | Risc0 | -//! |____________| |_____________| | SGX | -//! | SP1 | -//! |______________| -//! _____________________________ -//! =========> | Prover Networks | -//! | Risc0's Bonsai | -//! | Succinct's Prover Network | -//! |____________________________| -//! _________________________ -//! =========> | Raiko-dist | -//! | Distributed Risc0 | -//! | Distributed SP1 | -//! |_______________________| -//! -//! We would position Raiko task manager either before Raiko-host or after Raiko-host. -//! -//! ## Implementation -//! -//! The task manager is a set of tables and KV-stores. -//! - Keys for table joins are prefixed with id -//! - KV-stores for (almost) immutable data -//! - KV-store for large inputs and indistinguishable from random proofs -//! - Tables for tasks and their metadata. -//! -//! __________________________ -//! | metadata | -//! |_________________________| A simple KV-store with the DB version for migration/upgrade detection. -//! | Key | Value | Future version may add new fields, without breaking older versions. -//! |_________________|_______| -//! | task_db_version | 0 | -//! |_________________|_______| -//! -//! ________________________ -//! | Proof systems | -//! |______________________| A map: ID -> proof systems -//! | id_proofsys | Desc | -//! |_____________|________| -//! | 0 | Risc0 | (0 for Risc0 and 1 for SP1 is intentional) -//! | 1 | SP1 | -//! | 2 | SGX | -//! |_____________|________| -//! -//! _________________________________________________ -//! | Task Status code | -//! |________________________________________________| -//! | id_status | Desc | -//! |_____________|__________________________________| -//! | 0 | Success | -//! | 1000 | Registered | -//! | 2000 | Work-in-progress | -//! | | | -//! | -1000 | Proof failure (prover - generic) | -//! | -1100 | Proof failure (OOM) | -//! | | | -//! | -2000 | Network failure | -//! | | | -//! | -3000 | Cancelled | -//! | -3100 | Cancelled (never started) | -//! | -3200 | Cancelled (aborted) | -//! | -3210 | Cancellation in progress | (Yes -3210 is intentional ;)) -//! | | | -//! | -4000 | Invalid or unsupported block | -//! | | | -//! | -9999 | Unspecified failure reason | -//! |_____________|__________________________________| -//! -//! Rationale: -//! - Convention, failures use negative status code. -//! - We leave space for new status codes -//! - -X000 status code are for generic failures segregated by failures: -//! on the networking side, the prover side or trying to prove an invalid block. -//! -//! A catchall -9999 error code is provided if a failure is not due to -//! either the network, the prover or the requester invalid block. -//! They should not exist in the DB and a proper analysis -//! and eventually status code should be assigned. -//! -//! ________________________________________________________________________________________________ -//! | Tasks metadata | -//! |________________________________________________________________________________________________| -//! | id_task | chain_id | block_number | blockhash | parent_hash | state_root | # of txs | gas_used | -//! |_________|__________|______________|___________|_____________|____________|__________|__________| -//! ____________________________________ -//! | Task queue | -//! |___________________________________| -//! | id_task | blockhash | id_proofsys | -//! |_________|___________|_____________| -//! ______________________________________ -//! | Task payloads | -//! |_____________________________________| -//! | id_task | inputs (serialized) | -//! |_________|___________________________| -//! _____________________________________ -//! | Task requests | -//! |____________________________________| -//! | id_task | id_submitter | timestamp | -//! |_________|______________|___________| -//! ___________________________________________________________________________________ -//! | Task progress trail | -//! |__________________________________________________________________________________| -//! | id_task | third_party | id_status | timestamp | -//! |_________|________________________|_________________________|_____________________| -//! | 101 | 'Based Proposer" | 1000 (Registered) | 2024-01-01 00:00:01 | -//! | 101 | 'A Prover Network' | 2000 (WIP) | 2024-01-01 00:00:01 | -//! | 101 | 'A Prover Network' | -2000 (Network failure) | 2024-01-01 00:02:00 | -//! | 101 | 'Proof in the Pudding' | 2000 (WIP) | 2024-01-01 00:02:30 | -//!·| 101 | 'Proof in the Pudding' | 0 (Success) | 2024-01-01 01:02:30 | -//! -//! Rationale: -//! - payloads are very large and warrant a dedicated table, with pruning -//! - metadata is useful to audit block building and prover efficiency -//! - Due to failures and retries, we may submit the same task to multiple fulfillers -//! or retry with the same fulfiller so we keep an audit trail of events. -//! -//! ____________________________ -//! | Proof cache | A map: ID -> proof -//! |___________________________| -//! | id_task | proof_value | -//! |__________|________________| A Groth16 proof is 2G₁+1G₂ elements -//! | 0 | 0xabcd...6789 | On BN254: 2*(2*32)+1*(2*2*32) = 256 bytes -//! | 1 | 0x1234...cdef | -//! | ... | ... | A SGX proof is ... -//! |__________|________________| A Stark proof (not wrapped in Groth16) would be several kilobytes -//! -//! Do we need pruning? -//! There are 60s * 60min * 24h * 30j = 2592000s in a month -//! dividing by 12, that's 216000 Ethereum slots. -//! Assuming 1kB of proofs per block (Stark-to-Groth16 Risc0 & SP1 + SGX, SGX size to be verified) -//! That's only 216MB per month. - -// Imports -// ---------------------------------------------------------------- -use std::{ - fs::File, - path::Path, - sync::{Arc, Once}, -}; - -use chrono::{DateTime, Utc}; -use raiko_core::interfaces::AggregationOnlyRequest; -use raiko_lib::{ - primitives::B256, - prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}, -}; -use rusqlite::{ - named_params, {Connection, OpenFlags}, -}; -use tokio::sync::Mutex; - -use crate::{ - ProofTaskDescriptor, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, - TaskManagerResult, TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, -}; - -// Types -// ---------------------------------------------------------------- - -#[derive(Debug)] -pub struct TaskDb { - conn: Connection, -} - -pub struct SqliteTaskManager { - arc_task_db: Arc>, -} - -// Implementation -// ---------------------------------------------------------------- - -impl TaskDb { - fn open(path: &Path) -> TaskManagerResult { - let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_WRITE)?; - conn.pragma_update(None, "foreign_keys", true)?; - conn.pragma_update(None, "locking_mode", "EXCLUSIVE")?; - conn.pragma_update(None, "journal_mode", "WAL")?; - conn.pragma_update(None, "synchronous", "NORMAL")?; - conn.pragma_update(None, "temp_store", "MEMORY")?; - Ok(conn) - } - - fn create(path: &Path) -> TaskManagerResult { - let _file = File::options() - .write(true) - .read(true) - .create_new(true) - .open(path)?; - - let conn = Self::open(path)?; - Self::create_tables(&conn)?; - Self::create_views(&conn)?; - - Ok(conn) - } - - /// Open an existing TaskDb database at "path" - /// If a database does not exist at the path, one is created. - pub fn open_or_create(path: &Path) -> TaskManagerResult { - let conn = if path.exists() { - Self::open(path) - } else { - Self::create(path) - }?; - Ok(Self { conn }) - } - - // SQL - // ---------------------------------------------------------------- - - fn create_tables(conn: &Connection) -> TaskManagerResult<()> { - // Change the task_db_version if backward compatibility is broken - // and introduce a migration on DB opening ... if conserving history is important. - conn.execute_batch( - r#" - -- Key value store - ----------------------------------------------- - CREATE TABLE store( - chain_id INTEGER NOT NULL, - blockhash BLOB NOT NULL, - proofsys_id INTEGER NOT NULL, - id TEXT NOT NULL, - FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), - UNIQUE (chain_id, blockhash, proofsys_id) - ); - - -- Metadata and mappings - ----------------------------------------------- - CREATE TABLE metadata( - key BLOB UNIQUE NOT NULL PRIMARY KEY, - value BLOB - ); - - INSERT INTO - metadata(key, value) - VALUES - ('task_db_version', 0); - - CREATE TABLE proofsys( - id INTEGER UNIQUE NOT NULL PRIMARY KEY, - desc TEXT NOT NULL - ); - - INSERT INTO - proofsys(id, desc) - VALUES - (0, 'Native'), - (1, 'SP1'), - (2, 'SGX'), - (3, 'Risc0'); - - CREATE TABLE status_codes( - id INTEGER UNIQUE NOT NULL PRIMARY KEY, - desc TEXT NOT NULL - ); - - INSERT INTO - status_codes(id, desc) - VALUES - (0, 'Success'), - (1000, 'Registered'), - (2000, 'Work-in-progress'), - (-1000, 'Proof failure (generic)'), - (-1100, 'Proof failure (Out-Of-Memory)'), - (-2000, 'Network failure'), - (-3000, 'Cancelled'), - (-3100, 'Cancelled (never started)'), - (-3200, 'Cancelled (aborted)'), - (-3210, 'Cancellation in progress'), - (-4000, 'Invalid or unsupported block'), - (-9999, 'Unspecified failure reason'); - - -- Data - ----------------------------------------------- - -- Notes: - -- 1. a blockhash may appear as many times as there are prover backends. - -- 2. For query speed over (chain_id, blockhash) - -- there is no need to create an index as the UNIQUE constraint - -- has an implied index, see: - -- - https://sqlite.org/lang_createtable.html#uniqueconst - -- - https://www.sqlite.org/fileformat2.html#representation_of_sql_indices - CREATE TABLE tasks( - id INTEGER UNIQUE NOT NULL PRIMARY KEY, - chain_id INTEGER NOT NULL, - blockhash BLOB NOT NULL, - proofsys_id INTEGER NOT NULL, - prover TEXT NOT NULL, - FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), - UNIQUE (chain_id, blockhash, proofsys_id) - ); - - -- Proofs might also be large, so we isolate them in a dedicated table - CREATE TABLE task_proofs( - task_id INTEGER UNIQUE NOT NULL PRIMARY KEY, - proof TEXT, - FOREIGN KEY(task_id) REFERENCES tasks(id) - ); - - CREATE TABLE task_status( - task_id INTEGER NOT NULL, - status_id INTEGER NOT NULL, - timestamp TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL, - FOREIGN KEY(task_id) REFERENCES tasks(id), - FOREIGN KEY(status_id) REFERENCES status_codes(id), - UNIQUE (task_id, timestamp) - ); - "#, - )?; - - Ok(()) - } - - fn create_views(conn: &Connection) -> TaskManagerResult<()> { - // By convention, views will use an action verb as name. - conn.execute_batch( - r#" - CREATE VIEW enqueue_task AS - SELECT - t.id, - t.chain_id, - t.blockhash, - t.proofsys_id, - t.prover - FROM - tasks t - LEFT JOIN task_status ts on ts.task_id = t.id; - - CREATE VIEW update_task_progress AS - SELECT - t.id, - t.chain_id, - t.blockhash, - t.proofsys_id, - t.prover, - ts.status_id, - tpf.proof - FROM - tasks t - LEFT JOIN task_status ts on ts.task_id = t.id - LEFT JOIN task_proofs tpf on tpf.task_id = t.id; - "#, - )?; - - Ok(()) - } - - /// Set a tracer to debug SQL execution - /// for example: - /// db.set_tracer(Some(|stmt| println!("sqlite:\n-------\n{}\n=======", stmt))); - #[cfg(test)] - #[allow(dead_code)] - pub fn set_tracer(&mut self, trace_fn: Option) { - self.conn.trace(trace_fn); - } - - pub fn manage(&self) -> TaskManagerResult<()> { - // To update all the tables with the task_id assigned by Sqlite - // we require row IDs for the tasks table - // and we use last_insert_rowid() which is not reentrant and need a transaction lock - // and store them in a temporary table, configured to be in-memory. - // - // Alternative approaches considered: - // 1. Sqlite does not support variables (because it's embedded and significantly less overhead than other SQL "Client-Server" DBs). - // 2. using AUTOINCREMENT and/or the sqlite_sequence table - // - sqlite recommends not using AUTOINCREMENT for performance - // https://www.sqlite.org/autoinc.html - // 3. INSERT INTO ... RETURNING nested in a WITH clause (CTE / Common Table Expression) - // - Sqlite can only do RETURNING to the application, it cannot be nested in another query or diverted to another table - // https://sqlite.org/lang_returning.html#limitations_and_caveats - // 4. CREATE TEMPORARY TABLE AS with an INSERT INTO ... RETURNING nested - // - Same limitation AND CREATE TABLEAS seems to only support SELECT statements (but if we could nest RETURNING we can workaround that - // https://www.sqlite.org/lang_createtable.html#create_table_as_select_statements - // - // Hence we have to use row IDs and last_insert_rowid() - // - // Furthermore we use a view and an INSTEAD OF trigger to update the tables, - // the alternative being - // - // 5. Direct insert into tables - // This does not work as SQLite `execute` and `prepare` - // only process the first statement. - // - // And lastly, we need the view and trigger to be temporary because - // otherwise they can't access the temporary table: - // 6. https://sqlite.org/forum/info/4f998eeec510bceee69404541e5c9ca0a301868d59ec7c3486ecb8084309bba1 - // "Triggers in any schema other than temp may only access objects in their own schema. However, triggers in temp may access any object by name, even cross-schema." - self.conn.execute_batch( - r#" - -- PRAGMA temp_store = 'MEMORY'; - CREATE TEMPORARY TABLE IF NOT EXISTS temp.current_task(task_id INTEGER); - - CREATE TEMPORARY TRIGGER IF NOT EXISTS enqueue_task_insert_trigger INSTEAD OF - INSERT - ON enqueue_task - BEGIN - INSERT INTO - tasks(chain_id, blockhash, proofsys_id, prover) - VALUES - ( - new.chain_id, - new.blockhash, - new.proofsys_id, - new.prover - ); - - INSERT INTO - current_task - SELECT - id - FROM - tasks - WHERE - rowid = last_insert_rowid() - LIMIT - 1; - - -- Tasks are initialized at status 1000 - registered - -- timestamp is auto-filled with datetime('now'), see its field definition - INSERT INTO - task_status(task_id, status_id) - SELECT - tmp.task_id, - 1000 - FROM - current_task tmp; - - DELETE FROM - current_task; - END; - - CREATE TEMPORARY TRIGGER IF NOT EXISTS update_task_progress_trigger INSTEAD OF - INSERT - ON update_task_progress - BEGIN - INSERT INTO - current_task - SELECT - id - FROM - tasks - WHERE - chain_id = new.chain_id - AND blockhash = new.blockhash - AND proofsys_id = new.proofsys_id - LIMIT - 1; - - -- timestamp is auto-filled with datetime('now'), see its field definition - INSERT INTO - task_status(task_id, status_id) - SELECT - tmp.task_id, - new.status_id - FROM - current_task tmp - LIMIT - 1; - - INSERT - OR REPLACE INTO task_proofs - SELECT - task_id, - new.proof - FROM - current_task - WHERE - new.proof IS NOT NULL - LIMIT - 1; - - DELETE FROM - current_task; - END; - "#, - )?; - - Ok(()) - } - - pub fn enqueue_task( - &self, - ProofTaskDescriptor { - chain_id, - block_id, - blockhash, - proof_system, - prover, - }: &ProofTaskDescriptor, - ) -> TaskManagerResult { - let mut statement = self.conn.prepare_cached( - r#" - INSERT INTO - enqueue_task( - chain_id, - blockhash, - proofsys_id, - prover - ) - VALUES - ( - :chain_id, - :blockhash, - :proofsys_id, - :prover - ); - "#, - )?; - statement.execute(named_params! { - ":chain_id": chain_id, - ":block_id": block_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": *proof_system as u8, - ":prover": prover, - })?; - - Ok(TaskProvingStatusRecords(vec![( - TaskStatus::Registered, - Some(prover.clone()), - Utc::now(), - )])) - } - - pub fn update_task_progress( - &self, - ProofTaskDescriptor { - chain_id, - block_id, - blockhash, - proof_system, - prover, - }: ProofTaskDescriptor, - status: TaskStatus, - proof: Option<&[u8]>, - ) -> TaskManagerResult<()> { - let mut statement = self.conn.prepare_cached( - r#" - INSERT INTO - update_task_progress( - chain_id, - blockhash, - proofsys_id, - status_id, - prover, - proof - ) - VALUES - ( - :chain_id, - :blockhash, - :proofsys_id, - :status_id, - :prover, - :proof - ); - "#, - )?; - statement.execute(named_params! { - ":chain_id": chain_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_system as u8, - ":prover": prover, - ":status_id": i32::from(status), - ":proof": proof.map(hex::encode) - })?; - - Ok(()) - } - - pub fn get_task_proving_status( - &self, - ProofTaskDescriptor { - chain_id, - block_id, - blockhash, - proof_system, - prover, - }: &ProofTaskDescriptor, - ) -> TaskManagerResult { - let mut statement = self.conn.prepare_cached( - r#" - SELECT - ts.status_id, - tp.proof, - timestamp - FROM - task_status ts - LEFT JOIN tasks t ON ts.task_id = t.id - LEFT JOIN task_proofs tp ON tp.task_id = t.id - WHERE - t.chain_id = :chain_id - AND t.blockhash = :blockhash - AND t.proofsys_id = :proofsys_id - AND t.prover = :prover - ORDER BY - ts.timestamp; - "#, - )?; - let query = statement.query_map( - named_params! { - ":chain_id": chain_id, - ":block_id": block_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": *proof_system as u8, - ":prover": prover, - }, - |row| { - Ok(( - TaskStatus::from(row.get::<_, i32>(0)?), - row.get::<_, Option>(1)?, - row.get::<_, DateTime>(2)?, - )) - }, - )?; - - Ok(TaskProvingStatusRecords( - query.collect::, _>>()?, - )) - } - - pub fn get_task_proof( - &self, - ProofTaskDescriptor { - chain_id, - block_id, - blockhash, - proof_system, - prover, - }: &ProofTaskDescriptor, - ) -> TaskManagerResult> { - let mut statement = self.conn.prepare_cached( - r#" - SELECT - proof - FROM - task_proofs tp - LEFT JOIN tasks t ON tp.task_id = t.id - WHERE - t.chain_id = :chain_id - AND t.prover = :prover - AND t.blockhash = :blockhash - AND t.proofsys_id = :proofsys_id - LIMIT - 1; - "#, - )?; - let query = statement.query_row( - named_params! { - ":chain_id": chain_id, - ":block_id": block_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": *proof_system as u8, - ":prover": prover, - }, - |row| row.get::<_, Option>(0), - )?; - - let Some(proof) = query else { - return Ok(vec![]); - }; - - hex::decode(proof) - .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) - } - - pub fn get_db_size(&self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { - let mut statement = self.conn.prepare_cached( - r#" - SELECT - name as table_name, - SUM(pgsize) as table_size - FROM - dbstat - GROUP BY - table_name - ORDER BY - SUM(pgsize) DESC; - "#, - )?; - let query = statement.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?; - let details = query.collect::, _>>()?; - let total = details.iter().fold(0, |acc, (_, size)| acc + size); - - Ok((total, details)) - } - - pub fn prune_db(&self) -> TaskManagerResult<()> { - let mut statement = self.conn.prepare_cached( - r#" - DELETE FROM - tasks; - - DELETE FROM - task_proofs; - - DELETE FROM - task_status; - "#, - )?; - statement.execute([])?; - - Ok(()) - } - - pub fn list_all_tasks(&self) -> TaskManagerResult> { - let mut statement = self.conn.prepare_cached( - r#" - SELECT - chain_id, - blockhash, - proofsys_id, - prover, - status_id - FROM - tasks - LEFT JOIN task_status on task.id = task_status.task_id - JOIN ( - SELECT - task_id, - MAX(timestamp) as latest_timestamp - FROM - task_status - GROUP BY - task_id - ) latest_ts ON task_status.task_id = latest_ts.task_id - AND task_status.timestamp = latest_ts.latest_timestamp - "#, - )?; - let query = statement - .query_map([], |row| { - Ok(( - TaskDescriptor::SingleProof(ProofTaskDescriptor { - chain_id: row.get(0)?, - block_id: row.get(1)?, - blockhash: B256::from_slice(&row.get::<_, Vec>(1)?), - proof_system: row.get::<_, u8>(2)?.try_into().unwrap(), - prover: row.get(3)?, - }), - TaskStatus::from(row.get::<_, i32>(4)?), - )) - })? - .collect::, _>>()?; - - Ok(query) - } - - fn list_stored_ids(&self) -> TaskManagerResult> { - unimplemented!() - } - - fn store_id( - &self, - (chain_id, block_id, blockhash, proof_key): ProofKey, - id: String, - ) -> TaskManagerResult<()> { - let mut statement = self.conn.prepare_cached( - r#" - INSERT INTO - store( - chain_id, - blockhash, - proofsys_id, - id - ) - VALUES - ( - :chain_id, - :blockhash, - :proofsys_id, - :id - ); - "#, - )?; - statement.execute(named_params! { - ":chain_id": chain_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_key, - ":id": id, - })?; - - Ok(()) - } - - fn remove_id( - &self, - (chain_id, block_id, blockhash, proof_key): ProofKey, - ) -> TaskManagerResult<()> { - let mut statement = self.conn.prepare_cached( - r#" - DELETE FROM - store - WHERE - chain_id = :chain_id - AND blockhash = :blockhash - AND proofsys_id = :proofsys_id; - "#, - )?; - statement.execute(named_params! { - ":chain_id": chain_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_key, - })?; - - Ok(()) - } - - fn read_id( - &self, - (chain_id, block_id, blockhash, proof_key): ProofKey, - ) -> TaskManagerResult { - let mut statement = self.conn.prepare_cached( - r#" - SELECT - id - FROM - store - WHERE - chain_id = :chain_id - AND blockhash = :blockhash - AND proofsys_id = :proofsys_id - LIMIT - 1; - "#, - )?; - let query = match statement.query_row( - named_params! { - ":chain_id": chain_id, - ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_key, - }, - |row| row.get::<_, String>(0), - ) { - Ok(q) => q, - Err(e) => { - return match e { - rusqlite::Error::QueryReturnedNoRows => Err(TaskManagerError::NoData), - e => Err(e.into()), - } - } - }; - - Ok(query) - } -} - -#[async_trait::async_trait] -impl IdWrite for SqliteTaskManager { - async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { - let task_db = self.arc_task_db.lock().await; - task_db - .store_id(key, id) - .map_err(|e| ProverError::StoreError(e.to_string())) - } - - async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { - let task_db = self.arc_task_db.lock().await; - task_db - .remove_id(key) - .map_err(|e| ProverError::StoreError(e.to_string())) - } -} - -#[async_trait::async_trait] -impl IdStore for SqliteTaskManager { - async fn read_id(&self, key: ProofKey) -> ProverResult { - let task_db = self.arc_task_db.lock().await; - task_db - .read_id(key) - .map_err(|e| ProverError::StoreError(e.to_string())) - } -} - -#[async_trait::async_trait] -impl TaskManager for SqliteTaskManager { - fn new(opts: &TaskManagerOpts) -> Self { - static INIT: Once = Once::new(); - static mut CONN: Option>> = None; - INIT.call_once(|| { - unsafe { - CONN = Some(Arc::new(Mutex::new({ - let db = TaskDb::open_or_create(&opts.sqlite_file).unwrap(); - db.manage().unwrap(); - db - }))) - }; - }); - Self { - arc_task_db: unsafe { CONN.clone().unwrap() }, - } - } - - async fn enqueue_task( - &mut self, - params: &ProofTaskDescriptor, - ) -> Result { - let task_db: tokio::sync::MutexGuard<'_, TaskDb> = self.arc_task_db.lock().await; - task_db.enqueue_task(params) - } - - async fn update_task_progress( - &mut self, - key: ProofTaskDescriptor, - status: TaskStatus, - proof: Option<&[u8]>, - ) -> TaskManagerResult<()> { - let task_db = self.arc_task_db.lock().await; - task_db.update_task_progress(key, status, proof) - } - - /// Returns the latest triplet (submitter or fulfiller, status, last update time) - async fn get_task_proving_status( - &mut self, - key: &ProofTaskDescriptor, - ) -> TaskManagerResult { - let task_db = self.arc_task_db.lock().await; - task_db.get_task_proving_status(key) - } - - async fn get_task_proof(&mut self, key: &ProofTaskDescriptor) -> TaskManagerResult> { - let task_db = self.arc_task_db.lock().await; - task_db.get_task_proof(key) - } - - /// Returns the total and detailed database size - async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { - let task_db = self.arc_task_db.lock().await; - task_db.get_db_size() - } - - async fn prune_db(&mut self) -> TaskManagerResult<()> { - let task_db = self.arc_task_db.lock().await; - task_db.prune_db() - } - - async fn list_all_tasks(&mut self) -> TaskManagerResult> { - let task_db = self.arc_task_db.lock().await; - task_db.list_all_tasks() - } - - async fn list_stored_ids(&mut self) -> TaskManagerResult> { - let task_db = self.arc_task_db.lock().await; - task_db.list_stored_ids() - } - - async fn enqueue_aggregation_task( - &mut self, - _request: &AggregationOnlyRequest, - ) -> TaskManagerResult<()> { - todo!() - } - - async fn get_aggregation_task_proving_status( - &mut self, - _request: &AggregationOnlyRequest, - ) -> TaskManagerResult { - todo!() - } - - async fn update_aggregation_task_progress( - &mut self, - _request: &AggregationOnlyRequest, - _status: TaskStatus, - _proof: Option<&[u8]>, - ) -> TaskManagerResult<()> { - todo!() - } - - async fn get_aggregation_task_proof( - &mut self, - _request: &AggregationOnlyRequest, - ) -> TaskManagerResult> { - todo!() - } -} - -#[cfg(test)] -mod tests { - // We only test private functions here. - // Public API will be tested in a dedicated tests folder - - use super::*; - use tempfile::tempdir; - - #[test] - fn error_on_missing() { - let dir = tempdir().unwrap(); - let file = dir.path().join("db.sqlite"); - assert!(TaskDb::open(&file).is_err()); - } - - #[test] - fn ensure_exclusive() { - let dir = tempdir().unwrap(); - let file = dir.path().join("db.sqlite"); - - let _db = TaskDb::create(&file).unwrap(); - assert!(TaskDb::open(&file).is_err()); - std::fs::remove_file(&file).unwrap(); - } - - #[test] - fn ensure_unicity() { - let dir = tempdir().unwrap(); - let file = dir.path().join("db.sqlite"); - - let _db = TaskDb::create(&file).unwrap(); - assert!(TaskDb::create(&file).is_err()); - std::fs::remove_file(&file).unwrap(); - } -} diff --git a/taskdb/src/lib.rs b/taskdb/src/lib.rs index 1e7d517d..1f630b04 100644 --- a/taskdb/src/lib.rs +++ b/taskdb/src/lib.rs @@ -1,29 +1,21 @@ -use std::{ - io::{Error as IOError, ErrorKind as IOErrorKind}, - path::PathBuf, -}; +use std::io::{Error as IOError, ErrorKind as IOErrorKind}; use chrono::{DateTime, Utc}; -use raiko_core::interfaces::{AggregationOnlyRequest, ProofType}; +use raiko_core::interfaces::AggregationOnlyRequest; use raiko_lib::{ primitives::{ChainId, B256}, + proof_type::ProofType, prover::{IdStore, IdWrite, ProofKey, ProverResult}, }; -#[cfg(feature = "sqlite")] -use rusqlite::Error as RustSQLiteError; use serde::{Deserialize, Serialize}; use tracing::debug; use utoipa::ToSchema; -#[cfg(feature = "sqlite")] -use crate::adv_sqlite::SqliteTaskManager; #[cfg(feature = "in-memory")] use crate::mem_db::InMemoryTaskManager; #[cfg(feature = "redis-db")] use crate::redis_db::RedisTaskManager; -#[cfg(feature = "sqlite")] -mod adv_sqlite; #[cfg(feature = "in-memory")] mod mem_db; #[cfg(feature = "redis-db")] @@ -35,8 +27,6 @@ mod redis_db; pub enum TaskManagerError { #[error("IO Error {0}")] IOError(IOErrorKind), - #[error("SQL Error {0}")] - SqlError(String), #[cfg(feature = "redis-db")] #[error("Redis Error {0}")] RedisError(#[from] crate::redis_db::RedisDbError), @@ -54,19 +44,6 @@ impl From for TaskManagerError { } } -#[cfg(feature = "sqlite")] -impl From for TaskManagerError { - fn from(error: RustSQLiteError) -> TaskManagerError { - TaskManagerError::SqlError(error.to_string()) - } -} - -impl From for TaskManagerError { - fn from(error: serde_json::Error) -> TaskManagerError { - TaskManagerError::SqlError(error.to_string()) - } -} - impl From for TaskManagerError { fn from(value: anyhow::Error) -> Self { TaskManagerError::Anyhow(value.to_string()) @@ -236,7 +213,6 @@ pub type TaskReport = (TaskDescriptor, TaskStatus); #[derive(Debug, Clone, Default)] pub struct TaskManagerOpts { - pub sqlite_file: PathBuf, pub max_db_size: usize, pub redis_url: String, pub redis_ttl: u64, @@ -425,8 +401,6 @@ impl TaskManager for TaskManagerWrapper { #[cfg(feature = "in-memory")] pub type TaskManagerWrapperImpl = TaskManagerWrapper; -#[cfg(feature = "sqlite")] -pub type TaskManagerWrapperImpl = TaskManagerWrapper; #[cfg(feature = "redis-db")] pub type TaskManagerWrapperImpl = TaskManagerWrapper; @@ -438,29 +412,23 @@ pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapperImpl { #[cfg(test)] mod test { use super::*; - use std::path::Path; + use rand::Rng; #[tokio::test] async fn test_new_taskmanager() { - let sqlite_file: &Path = Path::new("test.db"); - // remove existed one - if sqlite_file.exists() { - std::fs::remove_file(sqlite_file).unwrap(); - } - let opts = TaskManagerOpts { - sqlite_file: sqlite_file.to_path_buf(), max_db_size: 1024 * 1024, redis_url: "redis://localhost:6379".to_string(), redis_ttl: 3600, }; let mut task_manager = get_task_manager(&opts); + let block_id = rand::thread_rng().gen_range(0..1000000); assert_eq!( task_manager .enqueue_task(&ProofTaskDescriptor { chain_id: 1, - block_id: 0, + block_id, blockhash: B256::default(), proof_system: ProofType::Native, prover: "test".to_string(), @@ -475,22 +443,16 @@ mod test { #[tokio::test] async fn test_enqueue_twice() { - let sqlite_file: &Path = Path::new("test.db"); - // remove existed one - if sqlite_file.exists() { - std::fs::remove_file(sqlite_file).unwrap(); - } - let opts = TaskManagerOpts { - sqlite_file: sqlite_file.to_path_buf(), max_db_size: 1024 * 1024, redis_url: "redis://localhost:6379".to_string(), redis_ttl: 3600, }; let mut task_manager = get_task_manager(&opts); + let block_id = rand::thread_rng().gen_range(0..1000000); let key = ProofTaskDescriptor { chain_id: 1, - block_id: 0, + block_id, blockhash: B256::default(), proof_system: ProofType::Native, prover: "test".to_string(), diff --git a/taskdb/src/mem_db.rs b/taskdb/src/mem_db.rs index 508c819d..8b6b41b7 100644 --- a/taskdb/src/mem_db.rs +++ b/taskdb/src/mem_db.rs @@ -98,21 +98,21 @@ impl InMemoryTaskDb { let proving_status_records = self .tasks_queue .get(key) - .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; + .ok_or_else(|| TaskManagerError::Anyhow("no task in db".to_owned()))?; let (_, proof, ..) = proving_status_records .0 .iter() .filter(|(status, ..)| (status == &TaskStatus::Success)) .last() - .ok_or_else(|| TaskManagerError::SqlError("no successful task in db".to_owned()))?; + .ok_or_else(|| TaskManagerError::Anyhow("no successful task in db".to_owned()))?; let Some(proof) = proof else { return Ok(vec![]); }; hex::decode(proof) - .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) + .map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned())) } fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { @@ -238,21 +238,21 @@ impl InMemoryTaskDb { let proving_status_records = self .aggregation_tasks_queue .get(request) - .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; + .ok_or_else(|| TaskManagerError::Anyhow("no task in db".to_owned()))?; let (_, proof, ..) = proving_status_records .0 .iter() .filter(|(status, ..)| (status == &TaskStatus::Success)) .last() - .ok_or_else(|| TaskManagerError::SqlError("no successful task in db".to_owned()))?; + .ok_or_else(|| TaskManagerError::Anyhow("no successful task in db".to_owned()))?; let Some(proof) = proof else { return Ok(vec![]); }; hex::decode(proof) - .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) + .map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned())) } } diff --git a/taskdb/tests/main.rs b/taskdb/tests/main.rs index 13134e34..74b2a91b 100644 --- a/taskdb/tests/main.rs +++ b/taskdb/tests/main.rs @@ -10,11 +10,11 @@ mod tests { use std::{collections::HashMap, env, time::Duration}; use alloy_primitives::Address; - use raiko_core::interfaces::{ProofRequest, ProofType}; + use raiko_core::interfaces::ProofRequest; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; - use raiko_lib::{input::BlobProofType, primitives::B256}; + use raiko_lib::{input::BlobProofType, primitives::B256, proof_type::ProofType}; use raiko_tasks::{ get_task_manager, ProofTaskDescriptor, TaskManager, TaskManagerOpts, TaskStatus, }; @@ -52,20 +52,7 @@ mod tests { #[tokio::test] async fn test_enqueue_task() { - // // Materialized local DB - // let dir = std::env::current_dir().unwrap().join("tests"); - // let file = dir.as_path().join("test_enqueue_task.sqlite"); - // if file.exists() { - // std::fs::remove_file(&file).unwrap() - // }; - - // temp dir DB - use tempfile::tempdir; - let dir = tempdir().unwrap(); - let file = dir.path().join("test_enqueue_task.sqlite"); - let mut tama = get_task_manager(&TaskManagerOpts { - sqlite_file: file, max_db_size: 1_000_000, redis_url: env::var("REDIS_URL").unwrap_or_default(), redis_ttl: 3600, @@ -89,22 +76,7 @@ mod tests { #[tokio::test] async fn test_update_query_tasks_progress() { - // Materialized local DB - let dir = std::env::current_dir().unwrap().join("tests"); - let file = dir - .as_path() - .join("test_update_query_tasks_progress.sqlite"); - if file.exists() { - std::fs::remove_file(&file).unwrap() - }; - - // // temp dir DB - // use tempfile::tempdir; - // let dir = tempdir().unwrap(); - // let file = dir.path().join("test_update_task_progress.sqlite"); - let mut tama = get_task_manager(&TaskManagerOpts { - sqlite_file: file, max_db_size: 1_000_000, redis_url: env::var("REDIS_URL").unwrap_or_default(), redis_ttl: 3600,