From 2b37fbd1e09c8e030b117af79c84ad7fdcc41bf6 Mon Sep 17 00:00:00 2001 From: Zach Brown Date: Thu, 22 Feb 2024 16:12:59 -0800 Subject: [PATCH] Implement on-demand recursive circuit table loading (#21) * Implement on-demand recursive circuit table loading * Update common/src/prover_state/mod.rs Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> * Update common/src/prover_state/mod.rs Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> * Update common/src/prover_state/persistence.rs Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> * address PR feedback * Make table load strategy configurable from CLI --------- Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> --- Cargo.lock | 10 +- common/Cargo.toml | 2 + common/src/prover_state/circuit.rs | 47 +++- common/src/prover_state/cli.rs | 47 +++- common/src/prover_state/mod.rs | 355 +++++++++++++++++++------ common/src/prover_state/persistence.rs | 310 ++++++++++++++------- leader/src/main.rs | 14 +- ops/src/lib.rs | 17 +- verifier/src/main.rs | 8 +- worker/src/main.rs | 9 +- 10 files changed, 602 insertions(+), 217 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9550dd49..f3560cc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -600,11 +600,13 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" name = "common" version = "0.1.0" dependencies = [ + "anyhow", "clap", "evm_arithmetization", "plonky2", "proof_gen", "thiserror", + "trace_decoder", "tracing", ] @@ -2736,18 +2738,18 @@ checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", diff --git a/common/Cargo.toml b/common/Cargo.toml index 909a0a27..8dc464d5 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -15,3 +15,5 @@ proof_gen = { workspace = true } plonky2 = { workspace = true } evm_arithmetization = { workspace = true } clap = { workspace = true } +anyhow = { workspace = true } +trace_decoder = { workspace = true } diff --git a/common/src/prover_state/circuit.rs b/common/src/prover_state/circuit.rs index 5813b688..6ab1a03a 100644 --- a/common/src/prover_state/circuit.rs +++ b/common/src/prover_state/circuit.rs @@ -11,7 +11,9 @@ use proof_gen::types::AllRecursiveCircuits; use crate::parsing::{parse_range, RangeParseError}; /// Number of tables defined in plonky2. -const NUM_TABLES: usize = 7; +/// +/// TODO: This should be made public in the evm_arithmetization crate. +pub(crate) const NUM_TABLES: usize = 7; /// New type wrapper for [`Range`] that implements [`FromStr`] and [`Display`]. /// @@ -111,6 +113,19 @@ impl Circuit { Circuit::Memory => "memory", } } + + /// Get the circuit name as a short str literal. + pub const fn as_short_str(&self) -> &'static str { + match self { + Circuit::Arithmetic => "a", + Circuit::BytePacking => "bp", + Circuit::Cpu => "c", + Circuit::Keccak => "k", + Circuit::KeccakSponge => "ks", + Circuit::Logic => "l", + Circuit::Memory => "m", + } + } } impl From for Circuit { @@ -128,11 +143,27 @@ impl From for Circuit { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CircuitConfig { circuits: [Range; NUM_TABLES], } +impl std::ops::Index for CircuitConfig { + type Output = Range; + + fn index(&self, index: usize) -> &Self::Output { + &self.circuits[index] + } +} + +impl std::ops::Index for CircuitConfig { + type Output = Range; + + fn index(&self, index: Circuit) -> &Self::Output { + &self.circuits[index as usize] + } +} + impl Default for CircuitConfig { fn default() -> Self { Self { @@ -176,16 +207,8 @@ impl CircuitConfig { /// Get a unique string representation of the config. pub fn get_configuration_digest(&self) -> String { self.enumerate() - .map(|(circuit, range)| match circuit { - Circuit::Arithmetic => format!("a_{}-{}", range.start, range.end), - Circuit::BytePacking => format!("b_p_{}-{}", range.start, range.end), - Circuit::Cpu => format!("c_{}-{}", range.start, range.end), - Circuit::Keccak => format!("k_{}-{}", range.start, range.end), - Circuit::KeccakSponge => { - format!("k_s_{}-{}", range.start, range.end) - } - Circuit::Logic => format!("l_{}-{}", range.start, range.end), - Circuit::Memory => format!("m_{}-{}", range.start, range.end), + .map(|(circuit, range)| { + format!("{}_{}-{}", circuit.as_short_str(), range.start, range.end) }) .fold(String::new(), |mut acc, s| { if !acc.is_empty() { diff --git a/common/src/prover_state/cli.rs b/common/src/prover_state/cli.rs index 51e067df..5a78a186 100644 --- a/common/src/prover_state/cli.rs +++ b/common/src/prover_state/cli.rs @@ -1,10 +1,12 @@ //! CLI arguments for constructing a [`CircuitConfig`], which can be used to //! construct table circuits. -use clap::Args; +use std::fmt::Display; + +use clap::{Args, ValueEnum}; use super::{ circuit::{Circuit, CircuitConfig, CircuitSize}, - CircuitPersistence, ProverStateConfig, + ProverStateManager, TableLoadStrategy, }; /// The help heading for the circuit arguments. @@ -21,6 +23,33 @@ fn circuit_arg_desc(circuit_name: &str) -> String { format!("The min/max size for the {circuit_name} table circuit.") } +/// Specifies whether to persist the processed circuits. +#[derive(Debug, Clone, Copy, ValueEnum)] +pub enum CircuitPersistence { + /// Do not persist the processed circuits. + None, + /// Persist the processed circuits to disk. + Disk, +} + +impl CircuitPersistence { + pub fn with_load_strategy(self, load_strategy: TableLoadStrategy) -> super::CircuitPersistence { + match self { + CircuitPersistence::None => super::CircuitPersistence::None, + CircuitPersistence::Disk => super::CircuitPersistence::Disk(load_strategy), + } + } +} + +impl Display for CircuitPersistence { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CircuitPersistence::None => write!(f, "none"), + CircuitPersistence::Disk => write!(f, "disk"), + } + } +} + /// Macro for generating the [`CliCircuitConfig`] struct. macro_rules! gen_prover_state_config { ($($name:ident: $circuit:expr),*) => { @@ -28,6 +57,8 @@ macro_rules! gen_prover_state_config { pub struct CliProverStateConfig { #[clap(long, help_heading = HEADING, default_value_t = CircuitPersistence::Disk)] pub persistence: CircuitPersistence, + #[clap(long, help_heading = HEADING, default_value_t = TableLoadStrategy::OnDemand)] + pub load_strategy: TableLoadStrategy, $( #[clap( @@ -73,16 +104,16 @@ impl CliProverStateConfig { config } - pub fn into_prover_state_config(self) -> ProverStateConfig { - ProverStateConfig { - persistence: self.persistence, + pub fn into_prover_state_manager(self) -> ProverStateManager { + ProverStateManager { + persistence: self.persistence.with_load_strategy(self.load_strategy), circuit_config: self.into_circuit_config(), } } } -impl From for ProverStateConfig { - fn from(item: CliProverStateConfig) -> Self { - item.into_prover_state_config() +impl From for ProverStateManager { + fn from(config: CliProverStateConfig) -> Self { + config.into_prover_state_manager() } } diff --git a/common/src/prover_state/mod.rs b/common/src/prover_state/mod.rs index c9af6c2b..090201f1 100644 --- a/common/src/prover_state/mod.rs +++ b/common/src/prover_state/mod.rs @@ -2,23 +2,48 @@ //! //! This module provides the following: //! - [`Circuit`] and [`CircuitConfig`] which can be used to dynamically -//! construct [`AllRecursiveCircuits`] from the specified circuit sizes. +//! construct [`evm_arithmetization::fixed_recursive_verifier::AllRecursiveCircuits`] +//! from the specified circuit sizes. //! - Command line arguments for constructing a [`CircuitConfig`]. //! - Provides default values for the circuit sizes. //! - Allows the circuit sizes to be specified via environment variables. -//! - Persistence utilities for saving and loading [`AllRecursiveCircuits`]. +//! - Persistence utilities for saving and loading +//! [`evm_arithmetization::fixed_recursive_verifier::AllRecursiveCircuits`]. //! - Global prover state management via the [`P_STATE`] static and the //! [`set_prover_state_from_config`] function. use std::{fmt::Display, sync::OnceLock}; use clap::ValueEnum; -use proof_gen::{prover_state::ProverState, VerifierState}; +use evm_arithmetization::{proof::AllProof, prover::prove, AllStark, StarkConfig}; +use plonky2::{ + field::goldilocks_field::GoldilocksField, plonk::config::PoseidonGoldilocksConfig, + util::timing::TimingTree, +}; +use proof_gen::{proof_types::GeneratedTxnProof, prover_state::ProverState, VerifierState}; +use trace_decoder::types::TxnProofGenIR; use tracing::info; +use self::circuit::{CircuitConfig, NUM_TABLES}; +use crate::prover_state::persistence::{ + BaseProverResource, DiskResource, MonolithicProverResource, RecursiveCircuitResource, + VerifierResource, +}; + pub mod circuit; pub mod cli; pub mod persistence; +pub(crate) type Config = PoseidonGoldilocksConfig; +pub(crate) type Field = GoldilocksField; +pub(crate) const SIZE: usize = 2; + +pub(crate) type RecursiveCircuitsForTableSize = + evm_arithmetization::fixed_recursive_verifier::RecursiveCircuitsForTableSize< + Field, + Config, + SIZE, + >; + /// The global prover state. /// /// It is specified as a `OnceLock` for the following reasons: @@ -28,107 +53,281 @@ pub mod persistence; /// - This scheme works for both a cluster and a single machine. In particular, /// whether imported from a worker node or a thread in the leader node /// (in-memory mode), the prover state is initialized only once. -pub static P_STATE: OnceLock = OnceLock::new(); +static P_STATE: OnceLock = OnceLock::new(); + +/// The global prover state manager. +/// +/// Unlike the prover state, the prover state manager houses configuration and +/// persistence information. This allows it to differentiate between the +/// different transaction proof generation strategies. As such, it is generally +/// only necessary when generating transaction proofs. +/// +/// It's specified as a `OnceLock` for the same reasons as the prover state. +static MANAGER: OnceLock = OnceLock::new(); + +pub fn p_state() -> &'static ProverState { + P_STATE.get().expect("Prover state is not initialized") +} + +pub fn p_manager() -> &'static ProverStateManager { + MANAGER + .get() + .expect("Prover state manager is not initialized") +} + +/// Specifies how to load the table circuits. +#[derive(Debug, Clone, Copy, Default, ValueEnum)] +pub enum TableLoadStrategy { + #[default] + /// Load the circuit tables as needed for shrinking STARK proofs. + /// + /// - Generate a STARK proof. + /// - Compute the degree bits. + /// - Load the necessary table circuits. + OnDemand, + /// Load all the table circuits into a monolithic bundle. + Monolithic, +} + +impl Display for TableLoadStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TableLoadStrategy::OnDemand => write!(f, "on-demand"), + TableLoadStrategy::Monolithic => write!(f, "monolithic"), + } + } +} /// Specifies whether to persist the processed circuits. -#[derive(Debug, Clone, Copy, ValueEnum)] +#[derive(Debug, Clone, Copy)] pub enum CircuitPersistence { /// Do not persist the processed circuits. None, /// Persist the processed circuits to disk. - Disk, + Disk(TableLoadStrategy), } -impl Display for CircuitPersistence { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CircuitPersistence::None => write!(f, "none"), - CircuitPersistence::Disk => write!(f, "disk"), - } +impl Default for CircuitPersistence { + fn default() -> Self { + CircuitPersistence::Disk(TableLoadStrategy::default()) } } /// Product of [`CircuitConfig`] and [`CircuitPersistence`]. -#[derive(Debug)] -pub struct ProverStateConfig { - pub circuit_config: circuit::CircuitConfig, +/// +/// Provides helper utilities for interacting with the prover state in +/// accordance with the specified configuration and persistence strategy. +#[derive(Default, Debug, Clone)] +pub struct ProverStateManager { + pub circuit_config: CircuitConfig, pub persistence: CircuitPersistence, } -/// Initializes the global prover state. -pub fn set_prover_state_from_config( - ProverStateConfig { - circuit_config, - persistence, - }: ProverStateConfig, -) -> Result<(), ProverState> { - info!("initializing prover state..."); - let state = match persistence { - CircuitPersistence::None => { - info!("generating circuits..."); - ProverState { - state: circuit_config.as_all_recursive_circuits(), +impl ProverStateManager { + pub fn with_load_strategy(self, load_strategy: TableLoadStrategy) -> Self { + match self.persistence { + CircuitPersistence::None => self, + CircuitPersistence::Disk(_) => Self { + circuit_config: self.circuit_config, + persistence: CircuitPersistence::Disk(load_strategy), + }, + } + } + + /// Load the table circuits necessary to shrink the STARK proof. + /// + /// [`AllProof`] provides the necessary degree bits for each circuit via the + /// [`AllProof::degree_bits`] method. + /// Using this information, for each circuit, a tuple is returned, + /// containing: + /// 1. The loaded table circuit at the specified size. + /// 2. An offset indicating the position of the specified size within the + /// configured range used when pre-generating the circuits. + fn load_table_circuits( + &self, + config: &StarkConfig, + all_proof: &AllProof, + ) -> anyhow::Result<[(RecursiveCircuitsForTableSize, u8); NUM_TABLES]> { + let degrees = all_proof.degree_bits(config); + + /// Given a recursive circuit index (e.g., Arithmetic / 0), return a + /// tuple containing the loaded table at the specified size and + /// its offset relative to the configured range used to pre-process the + /// circuits. + macro_rules! circuit { + ($circuit_index:expr) => { + ( + RecursiveCircuitResource::get(&( + $circuit_index.into(), + degrees[$circuit_index], + )) + .map_err(|e| { + let circuit: $crate::prover_state::circuit::Circuit = $circuit_index.into(); + let size = degrees[$circuit_index]; + anyhow::Error::from(e).context(format!( + "Attempting to load circuit: {circuit:?} at size: {size}" + )) + })?, + (degrees[$circuit_index] - self.circuit_config[$circuit_index].start) as u8, + ) + }; + } + + Ok([ + circuit!(0), + circuit!(1), + circuit!(2), + circuit!(3), + circuit!(4), + circuit!(5), + circuit!(6), + ]) + } + + /// Generate a transaction proof using the specified input, loading the + /// circuit tables as needed to shrink the individual STARK proofs, and + /// finally aggregating them to a final transaction proof. + fn txn_proof_on_demand(&self, input: TxnProofGenIR) -> anyhow::Result { + let config = StarkConfig::standard_fast_config(); + let all_stark = AllStark::default(); + let all_proof = prove(&all_stark, &config, input, &mut TimingTree::default(), None)?; + + let table_circuits = self.load_table_circuits(&config, &all_proof)?; + + let (intern, p_vals) = + p_state() + .state + .prove_root_after_initial_stark(all_proof, &table_circuits, None)?; + + Ok(GeneratedTxnProof { intern, p_vals }) + } + + /// Generate a transaction proof using the specified input on the monolithic + /// circuit. + fn txn_proof_monolithic(&self, input: TxnProofGenIR) -> anyhow::Result { + let (intern, p_vals) = p_state().state.prove_root( + &AllStark::default(), + &StarkConfig::standard_fast_config(), + input, + &mut TimingTree::default(), + None, + )?; + + Ok(GeneratedTxnProof { p_vals, intern }) + } + + /// Generate a transaction proof using the specified input. + /// + /// The specific implementation depends on the persistence strategy. + /// - If the persistence strategy is [`CircuitPersistence::None`] or + /// [`CircuitPersistence::Disk`] with [`TableLoadStrategy::Monolithic`], + /// the monolithic circuit is used. + /// - If the persistence strategy is [`CircuitPersistence::Disk`] with + /// [`TableLoadStrategy::OnDemand`], the table circuits are loaded as + /// needed. + pub fn generate_txn_proof(&self, input: TxnProofGenIR) -> anyhow::Result { + match self.persistence { + CircuitPersistence::None | CircuitPersistence::Disk(TableLoadStrategy::Monolithic) => { + info!("using monolithic circuit {:?}", self); + self.txn_proof_monolithic(input) + } + CircuitPersistence::Disk(TableLoadStrategy::OnDemand) => { + info!("using on demand circuit {:?}", self); + self.txn_proof_on_demand(input) } } - CircuitPersistence::Disk => { - info!("attempting to load preprocessed circuits from disk..."); - let disk_state = persistence::prover_from_disk(&circuit_config); - match disk_state { - Some(circuits) => { - info!("successfully loaded preprocessed circuits from disk"); - ProverState { state: circuits } + } + + /// Initialize global prover state from the configuration. + pub fn initialize(&self) -> anyhow::Result<()> { + info!("initializing prover state..."); + + let state = match self.persistence { + CircuitPersistence::None => { + info!("generating circuits..."); + ProverState { + state: self.circuit_config.as_all_recursive_circuits(), } - None => { - info!("failed to load preprocessed circuits from disk. generating circuits..."); - let all_recursive_circuits = circuit_config.as_all_recursive_circuits(); - info!("saving preprocessed circuits to disk"); - persistence::to_disk(&all_recursive_circuits, &circuit_config); - ProverState { - state: all_recursive_circuits, + } + CircuitPersistence::Disk(strategy) => { + info!("attempting to load preprocessed circuits from disk..."); + + let disk_state = match strategy { + TableLoadStrategy::OnDemand => BaseProverResource::get(&self.circuit_config), + TableLoadStrategy::Monolithic => { + MonolithicProverResource::get(&self.circuit_config) + } + }; + + match disk_state { + Ok(circuits) => { + info!("successfully loaded preprocessed circuits from disk"); + ProverState { state: circuits } + } + Err(_) => { + info!("failed to load preprocessed circuits from disk. generating circuits..."); + let all_recursive_circuits = + self.circuit_config.as_all_recursive_circuits(); + info!("saving preprocessed circuits to disk"); + persistence::persist_all_to_disk( + &all_recursive_circuits, + &self.circuit_config, + )?; + ProverState { + state: all_recursive_circuits, + } } } } - } - }; + }; - P_STATE.set(state) -} + P_STATE.set(state).map_err(|_| { + anyhow::Error::msg( + "prover state already set. check the program logic to ensure it is only set once", + ) + .context("setting prover state") + })?; + + MANAGER.set(self.clone()).map_err(|_| { + anyhow::Error::msg( + "prover state manager already set. check the program logic to ensure it is only set once", + ) + .context("setting prover state manager") + })?; -/// Loads a verifier state from disk or generate it. -pub fn get_verifier_state_from_config( - ProverStateConfig { - circuit_config, - persistence, - }: ProverStateConfig, -) -> VerifierState { - info!("initializing verifier state..."); - match persistence { - CircuitPersistence::None => { - info!("generating circuit..."); - let prover_state = circuit_config.as_all_recursive_circuits(); - VerifierState { - state: prover_state.final_verifier_data(), + Ok(()) + } + + /// Loads a verifier state from disk or generate it. + pub fn verifier(&self) -> anyhow::Result { + info!("initializing verifier state..."); + match self.persistence { + CircuitPersistence::None => { + info!("generating circuit..."); + let prover_state = self.circuit_config.as_all_recursive_circuits(); + Ok(VerifierState { + state: prover_state.final_verifier_data(), + }) } - } - CircuitPersistence::Disk => { - info!("attempting to load preprocessed verifier circuit from disk..."); - let disk_state = persistence::verifier_from_disk(&circuit_config); - match disk_state { - Some(state) => { - info!("successfully loaded preprocessed verifier circuit from disk"); - VerifierState { state } - } - None => { - info!( - "failed to load preprocessed verifier circuit from disk. generating it..." - ); - let prover_state = circuit_config.as_all_recursive_circuits(); + CircuitPersistence::Disk(_) => { + info!("attempting to load preprocessed verifier circuit from disk..."); + let disk_state = VerifierResource::get(&self.circuit_config); - info!("saving preprocessed verifier circuit to disk"); - let state = prover_state.final_verifier_data(); - persistence::verifier_to_disk(&state, &circuit_config); + match disk_state { + Ok(state) => { + info!("successfully loaded preprocessed verifier circuit from disk"); + Ok(VerifierState { state }) + } + Err(_) => { + info!("failed to load preprocessed verifier circuit from disk. generating it..."); + let prover_state = self.circuit_config.as_all_recursive_circuits(); + + info!("saving preprocessed verifier circuit to disk"); + let state = prover_state.final_verifier_data(); + VerifierResource::put(&self.circuit_config, &state)?; - VerifierState { state } + Ok(VerifierState { state }) + } } } } diff --git a/common/src/prover_state/persistence.rs b/common/src/prover_state/persistence.rs index c55f15af..d807df3a 100644 --- a/common/src/prover_state/persistence.rs +++ b/common/src/prover_state/persistence.rs @@ -1,126 +1,256 @@ use std::{ + fmt::{Debug, Display}, fs::{self, OpenOptions}, io::Write, + path::Path, }; -use plonky2::{ - plonk::config::PoseidonGoldilocksConfig, - util::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer}, +use plonky2::util::serialization::{ + Buffer, DefaultGateSerializer, DefaultGeneratorSerializer, IoError, }; use proof_gen::types::{AllRecursiveCircuits, VerifierData}; -use tracing::{info, warn}; +use thiserror::Error; -use super::circuit::CircuitConfig; +use super::{ + circuit::{Circuit, CircuitConfig}, + Config, RecursiveCircuitsForTableSize, SIZE, +}; -type Config = PoseidonGoldilocksConfig; -const SIZE: usize = 2; const PROVER_STATE_FILE_PREFIX: &str = "./prover_state"; const VERIFIER_STATE_FILE_PREFIX: &str = "./verifier_state"; -fn get_serializers() -> (DefaultGateSerializer, DefaultGeneratorSerializer) { +fn get_serializers() -> ( + DefaultGateSerializer, + DefaultGeneratorSerializer, +) { let gate_serializer = DefaultGateSerializer; - let witness_serializer: DefaultGeneratorSerializer = DefaultGeneratorSerializer { - _phantom: Default::default(), - }; + let witness_serializer: DefaultGeneratorSerializer = + DefaultGeneratorSerializer::default(); (gate_serializer, witness_serializer) } -#[inline] -fn disk_path(circuit_config: &CircuitConfig, prefix: &str) -> String { - format!("{}_{}", prefix, circuit_config.get_configuration_digest()) +#[derive(Error, Debug)] +pub(crate) enum DiskResourceError { + #[error("Serialization error: {0}")] + Serialization(E), + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), } -/// Loads [`AllRecursiveCircuits`] from disk. -pub fn prover_from_disk(circuit_config: &CircuitConfig) -> Option { - let path = disk_path(circuit_config, PROVER_STATE_FILE_PREFIX); - let bytes = fs::read(&path).ok()?; - info!("found prover state at {path}"); - let (gate_serializer, witness_serializer) = get_serializers(); - info!("deserializing prover state..."); - let state = - AllRecursiveCircuits::from_bytes(&bytes, false, &gate_serializer, &witness_serializer); - - match state { - Ok(state) => Some(state), - Err(e) => { - warn!("failed to deserialize prover state, {e:?}"); - None - } +/// A trait for generic resources that may be written to and read from disk, +/// each with their own serialization and deserialization logic. +pub(crate) trait DiskResource { + /// The type of error that may arise while serializing or deserializing the + /// resource. + type Error: Debug + Display; + /// The type of resource being serialized, deserialized, and written to + /// disk. + type Resource; + /// The input type / configuration used to generate a unique path to the + /// resource on disk. + type PathConstrutor; + + /// Returns the path to the resource on disk. + fn path(p: &Self::PathConstrutor) -> impl AsRef; + + /// Serializes the resource to bytes. + fn serialize(r: &Self::Resource) -> Result, DiskResourceError>; + + /// Deserializes the resource from bytes. + fn deserialize(bytes: &[u8]) -> Result>; + + /// Reads the resource from disk and deserializes it. + fn get(p: &Self::PathConstrutor) -> Result> { + Self::deserialize(&fs::read(Self::path(p))?) + } + + /// Writes the resource to disk after serializing it. + fn put( + p: &Self::PathConstrutor, + r: &Self::Resource, + ) -> Result<(), DiskResourceError> { + Ok(OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(Self::path(p))? + .write_all(&Self::serialize(r)?)?) } } -/// Loads [`VerifierData`] from disk. -pub fn verifier_from_disk(circuit_config: &CircuitConfig) -> Option { - let path = disk_path(circuit_config, VERIFIER_STATE_FILE_PREFIX); - let bytes = fs::read(&path).ok()?; - info!("found verifier state at {path}"); - let (gate_serializer, _witness_serializer) = get_serializers(); - info!("deserializing verifier state..."); - let state = VerifierData::from_bytes(bytes, &gate_serializer); - - match state { - Ok(state) => Some(state), - Err(e) => { - warn!("failed to deserialize verifier state, {e:?}"); - None - } +/// Pre-generated circuits containing just the three higher-level circuits. +/// These are sufficient for generating aggregation proofs and block +/// proofs, but not for transaction proofs. +#[derive(Debug, Default)] +pub(crate) struct BaseProverResource; + +impl DiskResource for BaseProverResource { + type Resource = AllRecursiveCircuits; + type Error = IoError; + type PathConstrutor = CircuitConfig; + + fn path(p: &Self::PathConstrutor) -> String { + format!( + "{}_base_{}", + PROVER_STATE_FILE_PREFIX, + p.get_configuration_digest() + ) + } + + fn serialize(r: &Self::Resource) -> Result, DiskResourceError> { + let (gate_serializer, witness_serializer) = get_serializers(); + + r + // Note we are using the `true` flag to write only the upper circuits. + // The individual circuit tables are written separately below. + .to_bytes(true, &gate_serializer, &witness_serializer) + .map_err(DiskResourceError::Serialization) + } + + fn deserialize(bytes: &[u8]) -> Result> { + let (gate_serializer, witness_serializer) = get_serializers(); + AllRecursiveCircuits::from_bytes(bytes, true, &gate_serializer, &witness_serializer) + .map_err(DiskResourceError::Serialization) } } -/// Writes the provided [`AllRecursiveCircuits`] to disk, along with the -/// associated [`VerifierData`], in two distinct files. -pub fn to_disk(circuits: &AllRecursiveCircuits, circuit_config: &CircuitConfig) { - prover_to_disk(circuits, circuit_config); - verifier_to_disk(&circuits.final_verifier_data(), circuit_config); +/// Pre-generated circuits containing all circuits. +#[derive(Debug, Default)] +pub(crate) struct MonolithicProverResource; + +impl DiskResource for MonolithicProverResource { + type Resource = AllRecursiveCircuits; + type Error = IoError; + type PathConstrutor = CircuitConfig; + + fn path(p: &Self::PathConstrutor) -> String { + format!( + "{}_monolithic_{}", + PROVER_STATE_FILE_PREFIX, + p.get_configuration_digest() + ) + } + + fn serialize(r: &Self::Resource) -> Result, DiskResourceError> { + let (gate_serializer, witness_serializer) = get_serializers(); + + r + // Note we are using the `false` flag to write all circuits. + .to_bytes(false, &gate_serializer, &witness_serializer) + .map_err(DiskResourceError::Serialization) + } + + fn deserialize(bytes: &[u8]) -> Result> { + let (gate_serializer, witness_serializer) = get_serializers(); + AllRecursiveCircuits::from_bytes(bytes, false, &gate_serializer, &witness_serializer) + .map_err(DiskResourceError::Serialization) + } } -/// Writes the provided [`AllRecursiveCircuits`] to disk. -fn prover_to_disk(circuits: &AllRecursiveCircuits, circuit_config: &CircuitConfig) { - let (gate_serializer, witness_serializer) = get_serializers(); - - // Write prover state to disk - if let Err(e) = circuits - .to_bytes(false, &gate_serializer, &witness_serializer) - .map(|bytes| { - write_bytes_to_file(&bytes, disk_path(circuit_config, PROVER_STATE_FILE_PREFIX)) - }) - { - warn!("failed to create prover state file, {e:?}"); - }; +/// An individual circuit table with a specific size. +#[derive(Debug, Default)] +pub(crate) struct RecursiveCircuitResource; + +impl DiskResource for RecursiveCircuitResource { + type Resource = RecursiveCircuitsForTableSize; + type Error = IoError; + type PathConstrutor = (Circuit, usize); + + fn path((circuit_type, size): &Self::PathConstrutor) -> String { + format!( + "{}_{}_{}", + PROVER_STATE_FILE_PREFIX, + circuit_type.as_short_str(), + size + ) + } + + fn serialize(r: &Self::Resource) -> Result, DiskResourceError> { + let (gate_serializer, witness_serializer) = get_serializers(); + let mut buf = Vec::new(); + + r.to_buffer(&mut buf, &gate_serializer, &witness_serializer) + .map_err(DiskResourceError::Serialization)?; + + Ok(buf) + } + + fn deserialize( + bytes: &[u8], + ) -> Result> { + let (gate_serializer, witness_serializer) = get_serializers(); + let mut buffer = Buffer::new(bytes); + RecursiveCircuitsForTableSize::from_buffer( + &mut buffer, + &gate_serializer, + &witness_serializer, + ) + .map_err(DiskResourceError::Serialization) + } } -/// Writes the provided [`VerifierData`] to disk. -pub fn verifier_to_disk(circuit: &VerifierData, circuit_config: &CircuitConfig) { - let (gate_serializer, _witness_serializer) = get_serializers(); +/// An individual circuit table with a specific size. +#[derive(Debug, Default)] +pub(crate) struct VerifierResource; + +impl DiskResource for VerifierResource { + type Resource = VerifierData; + type Error = IoError; + type PathConstrutor = CircuitConfig; - // Write verifier state to disk - if let Err(e) = circuit.to_bytes(&gate_serializer).map(|bytes| { - write_bytes_to_file( - &bytes, - disk_path(circuit_config, VERIFIER_STATE_FILE_PREFIX), + fn path(p: &Self::PathConstrutor) -> String { + format!( + "{}_{}", + VERIFIER_STATE_FILE_PREFIX, + p.get_configuration_digest() ) - }) { - warn!("failed to create verifier state file, {e:?}"); - }; + } + + fn serialize(r: &Self::Resource) -> Result, DiskResourceError> { + let (gate_serializer, _witness_serializer) = get_serializers(); + r.to_bytes(&gate_serializer) + .map_err(DiskResourceError::Serialization) + } + + fn deserialize(bytes: &[u8]) -> Result> { + let (gate_serializer, _) = get_serializers(); + VerifierData::from_bytes(bytes.to_vec(), &gate_serializer) + .map_err(DiskResourceError::Serialization) + } } -fn write_bytes_to_file(bytes: &[u8], path: String) { - let file = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(path); - - let mut file = match file { - Ok(file) => file, - Err(e) => { - warn!("failed to create circuits file, {e:?}"); - return; - } - }; +/// Writes the provided [`AllRecursiveCircuits`] to disk with all +/// configurations, along with the associated [`VerifierData`]. +pub fn persist_all_to_disk( + circuits: &AllRecursiveCircuits, + circuit_config: &CircuitConfig, +) -> anyhow::Result<()> { + prover_to_disk(circuit_config, circuits)?; + VerifierResource::put(circuit_config, &circuits.final_verifier_data())?; - if let Err(e) = file.write_all(bytes) { - warn!("failed to write circuits file, {e:?}"); + Ok(()) +} + +/// Writes the provided [`AllRecursiveCircuits`] to disk. +/// +/// In particular, we cover both the monolothic and base prover states, as well +/// as the individual circuit tables. +fn prover_to_disk( + circuit_config: &CircuitConfig, + circuits: &AllRecursiveCircuits, +) -> Result<(), DiskResourceError> { + BaseProverResource::put(circuit_config, circuits)?; + MonolithicProverResource::put(circuit_config, circuits)?; + + // Write individual circuit tables to disk, by circuit type and size. This + // allows us to load only the necessary tables when needed. + for (circuit_type, tables) in circuits.by_table.iter().enumerate() { + let circuit_type: Circuit = circuit_type.into(); + for (size, table) in tables.by_stark_size.iter() { + RecursiveCircuitResource::put(&(circuit_type, *size), table)?; + } } + + Ok(()) } diff --git a/leader/src/main.rs b/leader/src/main.rs index 06590507..36d2f1e3 100644 --- a/leader/src/main.rs +++ b/leader/src/main.rs @@ -3,12 +3,11 @@ use std::{fs::File, path::PathBuf}; use anyhow::Result; use clap::Parser; use cli::Command; -use common::prover_state::set_prover_state_from_config; +use common::prover_state::TableLoadStrategy; use dotenvy::dotenv; use ops::register; use paladin::runtime::Runtime; use proof_gen::types::PlonkyProofIntern; -use tracing::warn; mod cli; mod http; @@ -37,11 +36,12 @@ async fn main() -> Result<()> { if let paladin::config::Runtime::InMemory = args.paladin.runtime { // If running in emulation mode, we'll need to initialize the prover // state here. - if set_prover_state_from_config(args.prover_state_config.into()).is_err() { - warn!( - "prover state already set. check the program logic to ensure it is only set once" - ); - } + args.prover_state_config + .into_prover_state_manager() + // Use the monolithic load strategy for the prover state when running in + // emulation mode. + .with_load_strategy(TableLoadStrategy::Monolithic) + .initialize()?; } let runtime = Runtime::from_config(&args.paladin, register()).await?; diff --git a/ops/src/lib.rs b/ops/src/lib.rs index 80fc2376..b51cff5f 100644 --- a/ops/src/lib.rs +++ b/ops/src/lib.rs @@ -1,20 +1,15 @@ -use common::prover_state::P_STATE; +use common::prover_state::{p_manager, p_state}; use paladin::{ - operation::{FatalError, Monoid, Operation, Result}, + operation::{FatalError, FatalStrategy, Monoid, Operation, Result}, registry, RemoteExecute, }; use proof_gen::{ - proof_gen::{generate_agg_proof, generate_block_proof, generate_txn_proof}, + proof_gen::{generate_agg_proof, generate_block_proof}, proof_types::{AggregatableProof, GeneratedAggProof, GeneratedBlockProof}, - prover_state::ProverState, }; use serde::{Deserialize, Serialize}; use trace_decoder::types::TxnProofGenIR; -fn p_state() -> &'static ProverState { - P_STATE.get().expect("Prover state is not initialized") -} - registry!(); #[derive(Deserialize, Serialize, RemoteExecute)] @@ -25,9 +20,11 @@ impl Operation for TxProof { type Output = AggregatableProof; fn execute(&self, input: Self::Input) -> Result { - let result = generate_txn_proof(p_state(), input, None).map_err(FatalError::from)?; + let proof = p_manager() + .generate_txn_proof(input) + .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?; - Ok(result.into()) + Ok(proof.into()) } } diff --git a/verifier/src/main.rs b/verifier/src/main.rs index 0ff99442..60ba1d52 100644 --- a/verifier/src/main.rs +++ b/verifier/src/main.rs @@ -2,7 +2,6 @@ use std::fs::File; use anyhow::Result; use clap::Parser; -use common::prover_state::get_verifier_state_from_config; use proof_gen::types::PlonkyProofIntern; use serde_json::Deserializer; @@ -17,9 +16,12 @@ fn main() -> Result<()> { let des = &mut Deserializer::from_reader(&file); let input: PlonkyProofIntern = serde_path_to_error::deserialize(des)?; - let verifier_state = get_verifier_state_from_config(args.prover_state_config.into()); + let verifer = args + .prover_state_config + .into_prover_state_manager() + .verifier()?; - verifier_state.verify(&input)?; + verifer.verify(&input)?; Ok(()) } diff --git a/worker/src/main.rs b/worker/src/main.rs index 7048d591..d24ee23a 100644 --- a/worker/src/main.rs +++ b/worker/src/main.rs @@ -1,10 +1,9 @@ use anyhow::Result; use clap::Parser; -use common::prover_state::{cli::CliProverStateConfig, set_prover_state_from_config}; +use common::prover_state::cli::CliProverStateConfig; use dotenvy::dotenv; use ops::register; use paladin::runtime::WorkerRuntime; -use tracing::warn; mod init; @@ -22,9 +21,9 @@ async fn main() -> Result<()> { init::tracing(); let args = Cli::parse(); - if set_prover_state_from_config(args.prover_state_config.into()).is_err() { - warn!("prover state already set. check the program logic to ensure it is only set once"); - } + args.prover_state_config + .into_prover_state_manager() + .initialize()?; let runtime = WorkerRuntime::from_config(&args.paladin, register()).await?; runtime.main_loop().await?;