From 8f3a572076f4778b8c5ec56e4ec15bbd0ee61778 Mon Sep 17 00:00:00 2001 From: ShuangWu121 <47602565+ShuangWu121@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:35:59 +0800 Subject: [PATCH] Stwo support "addition and equal test" (#1942) ## Initial Support for `stwo` Backend This PR introduces the initial support for the `stwo` backend in the Powdr pipeline. ### Summary - **Constraint Proving**: Added functionality to prove and verify constraints between columns. - **Test Added**: A test case has been added to validate the above functionality. ### Remaining Tasks - Implement the `setup` function. - Use `logup` for the next reference (now the next reference creates a new column). - Check the constant columns and public values in stwo and implement to Powdr. ### How to Test To test the `stwo` backend, use the following command: ```bash cargo test --features stwo --package powdr-pipeline --test pil -- stwo_add_and_equal --exact --show-output --------- Co-authored-by: Shuang Wu Co-authored-by: Shuang Wu Co-authored-by: Thibaut Schaeffer --- backend/Cargo.toml | 2 +- backend/src/lib.rs | 10 +- backend/src/stwo/circuit_builder.rs | 193 ++++++++++++++++++++++++++++ backend/src/stwo/mod.rs | 40 ++++-- backend/src/stwo/prover.rs | 133 +++++++++++++++++-- pipeline/src/test_util.rs | 26 ++++ pipeline/tests/pil.rs | 7 +- test_data/pil/add_and_equal.pil | 10 ++ 8 files changed, 390 insertions(+), 31 deletions(-) create mode 100644 backend/src/stwo/circuit_builder.rs create mode 100644 test_data/pil/add_and_equal.pil diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 6b2c56ba31..4210387c9a 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -62,7 +62,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2 p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true } # TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3 -stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "52d050c18b5dbc74af40214b3b441a6f60a20d41" } +stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" } strum = { version = "0.24.1", features = ["derive"] } log = "0.4.17" diff --git a/backend/src/lib.rs b/backend/src/lib.rs index cc06cb80b3..12cd783272 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -107,15 +107,7 @@ impl BackendType { Box::new(composite::CompositeBackendFactory::new(plonky3::Factory)) } #[cfg(feature = "stwo")] - BackendType::Stwo => Box::new(stwo::StwoProverFactory), - #[cfg(not(any( - feature = "halo2", - feature = "estark-polygon", - feature = "estark-starky", - feature = "plonky3", - feature = "stwo" - )))] - _ => panic!("Empty backend."), + BackendType::Stwo => Box::new(stwo::Factory), } } } diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs new file mode 100644 index 0000000000..3b72ecf0b6 --- /dev/null +++ b/backend/src/stwo/circuit_builder.rs @@ -0,0 +1,193 @@ +use num_traits::Zero; +use std::fmt::Debug; +use std::ops::{Add, AddAssign, Mul, Neg, Sub}; + +extern crate alloc; +use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec}; +use powdr_ast::analyzed::{ + AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity, +}; +use powdr_number::{FieldElement, LargeInt}; +use std::sync::Arc; + +use powdr_ast::analyzed::{ + AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType, +}; +use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval}; +use stwo_prover::core::backend::ColumnOps; +use stwo_prover::core::fields::m31::{BaseField, M31}; +use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::ColumnVec; + +pub type PowdrComponent<'a, F> = FrameworkComponent>; + +pub(crate) fn gen_stwo_circuit_trace( + witness: &[(String, Vec)], +) -> ColumnVec> +where + T: FieldElement, //only Merenne31Field is supported, checked in runtime + B: FieldOps + ColumnOps, // Ensure B implements FieldOps for M31 + F: ExtensionOf, +{ + assert!( + witness + .iter() + .all(|(_name, vec)| vec.len() == witness[0].1.len()), + "All Vec in witness must have the same length. Mismatch found!" + ); + let domain = CanonicCoset::new(witness[0].1.len().ilog2()).circle_domain(); + witness + .iter() + .map(|(_name, values)| { + let values = values + .iter() + .map(|v| v.try_into_i32().unwrap().into()) + .collect(); + CircleEvaluation::new(domain, values) + }) + .collect() +} + +pub struct PowdrEval { + analyzed: Arc>, + witness_columns: BTreeMap, +} + +impl PowdrEval { + pub fn new(analyzed: Arc>) -> Self { + let witness_columns: BTreeMap = analyzed + .definitions_in_source_order(PolynomialType::Committed) + .flat_map(|(symbol, _)| symbol.array_elements()) + .enumerate() + .map(|(index, (_, id))| (id, index)) + .collect(); + + Self { + analyzed, + witness_columns, + } + } +} + +impl FrameworkEval for PowdrEval { + fn log_size(&self) -> u32 { + self.analyzed.degree().ilog2() + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.analyzed.degree().ilog2() + 1 + } + fn evaluate(&self, mut eval: E) -> E { + assert!( + self.analyzed.constant_count() == 0 && self.analyzed.publics_count() == 0, + "Error: Expected no fixed columns nor public inputs, as they are not supported yet.", + ); + + let witness_eval: BTreeMap::F; 2]> = self + .witness_columns + .keys() + .map(|poly_id| (*poly_id, eval.next_interaction_mask(0, [0, 1]))) + .collect(); + + for id in self + .analyzed + .identities_with_inlined_intermediate_polynomials() + { + match id { + Identity::Polynomial(identity) => { + let expr = to_stwo_expression(&identity.expression, &witness_eval); + eval.add_constraint(expr); + } + Identity::Connect(..) => { + unimplemented!("Connect is not implemented in stwo yet") + } + Identity::Lookup(..) => { + unimplemented!("Lookup is not implemented in stwo yet") + } + Identity::Permutation(..) => { + unimplemented!("Permutation is not implemented in stwo yet") + } + Identity::PhantomPermutation(..) => {} + Identity::PhantomLookup(..) => {} + } + } + eval + } +} + +fn to_stwo_expression( + expr: &AlgebraicExpression, + witness_eval: &BTreeMap, +) -> F +where + F: FieldExpOps + + Clone + + Debug + + Zero + + Neg + + AddAssign + + AddAssign + + Add + + Sub + + Mul + + Neg + + From, +{ + use AlgebraicBinaryOperator::*; + match expr { + AlgebraicExpression::Reference(r) => { + let poly_id = r.poly_id; + + match poly_id.ptype { + PolynomialType::Committed => match r.next { + false => witness_eval[&poly_id][0].clone(), + true => witness_eval[&poly_id][1].clone(), + }, + PolynomialType::Constant => { + unimplemented!("Constant polynomials are not supported in stwo yet") + } + PolynomialType::Intermediate => { + unimplemented!("Intermediate polynomials are not supported in stwo yet") + } + } + } + AlgebraicExpression::PublicReference(..) => { + unimplemented!("Public references are not supported in stwo yet") + } + AlgebraicExpression::Number(n) => F::from(M31::from(n.try_into_i32().unwrap())), + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: Pow, + right, + }) => match **right { + AlgebraicExpression::Number(n) => { + let left = to_stwo_expression(left, witness_eval); + (0u32..n.to_integer().try_into_u32().unwrap()) + .fold(F::one(), |acc, _| acc * left.clone()) + } + _ => unimplemented!("pow with non-constant exponent"), + }, + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { + let left = to_stwo_expression(left, witness_eval); + let right = to_stwo_expression(right, witness_eval); + + match op { + Add => left + right, + Sub => left - right, + Mul => left * right, + Pow => unreachable!("This case was handled above"), + } + } + AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { + let expr = to_stwo_expression(expr, witness_eval); + + match op { + AlgebraicUnaryOperator::Minus => -expr, + } + } + AlgebraicExpression::Challenge(_challenge) => { + unimplemented!("challenges are not supported in stwo yet") + } + } +} diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index a9a0e92105..794b66401d 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -1,20 +1,28 @@ +use serde::de::DeserializeOwned; +use serde::Serialize; use std::io; use std::path::PathBuf; use std::sync::Arc; -use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; +use crate::{ + field_filter::generalize_factory, Backend, BackendFactory, BackendOptions, Error, Proof, +}; use powdr_ast::analyzed::Analyzed; use powdr_executor::constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}; use powdr_executor::witgen::WitgenCallback; -use powdr_number::FieldElement; +use powdr_number::{FieldElement, Mersenne31Field}; use prover::StwoProver; +use stwo_prover::core::backend::{simd::SimdBackend, BackendForChannel}; +use stwo_prover::core::channel::{Blake2sChannel, Channel, MerkleChannel}; +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; +mod circuit_builder; mod prover; - #[allow(dead_code)] -pub(crate) struct StwoProverFactory; -impl BackendFactory for StwoProverFactory { +struct RestrictedFactory; + +impl BackendFactory for RestrictedFactory { #[allow(unreachable_code)] #[allow(unused_variables)] fn create( @@ -37,16 +45,28 @@ impl BackendFactory for StwoProverFactory { let fixed = Arc::new( get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, ); - let stwo = Box::new(StwoProver::new(pil, fixed, setup)?); + let stwo: Box> = + Box::new(StwoProver::new(pil, fixed)?); Ok(stwo) } } -impl Backend for StwoProver { +generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]); + +impl Backend + for StwoProver +where + SimdBackend: BackendForChannel, + MC: MerkleChannel, + C: Channel, + MC::H: DeserializeOwned + Serialize, +{ #[allow(unused_variables)] fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { - assert!(instances.len() == 1); - unimplemented!() + assert_eq!(instances.len(), 1); + let instances = &instances[0]; + + Ok(self.verify(proof, instances)?) } #[allow(unreachable_code)] #[allow(unused_variables)] @@ -59,7 +79,7 @@ impl Backend for StwoProver { if prev_proof.is_some() { return Err(Error::NoAggregationAvailable); } - unimplemented!() + Ok(StwoProver::prove(self, witness)?) } #[allow(unused_variables)] fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> { diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index 35f66e0ca9..ab79e93b03 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,32 +1,145 @@ use powdr_ast::analyzed::Analyzed; +use serde::de::DeserializeOwned; +use serde::ser::Serialize; use std::io; +use std::marker::PhantomData; use std::sync::Arc; +use crate::stwo::circuit_builder::{gen_stwo_circuit_trace, PowdrComponent, PowdrEval}; + +use stwo_prover::constraint_framework::TraceLocationAllocator; +use stwo_prover::core::prover::StarkProof; + use powdr_number::FieldElement; +use stwo_prover::core::air::{Component, ComponentProver}; +use stwo_prover::core::backend::{Backend, BackendForChannel}; +use stwo_prover::core::channel::{Channel, MerkleChannel}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fri::FriConfig; +use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use stwo_prover::core::poly::circle::CanonicCoset; + +const FRI_LOG_BLOWUP: usize = 1; +const FRI_NUM_QUERIES: usize = 100; +const FRI_PROOF_OF_WORK_BITS: usize = 16; +const LOG_LAST_LAYER_DEGREE_BOUND: usize = 0; -#[allow(unused_variables)] -pub struct StwoProver { - _analyzed: Arc>, - _fixed: Arc)>>, +pub struct StwoProver { + pub analyzed: Arc>, + _fixed: Arc)>>, /// Proving key placeholder _proving_key: Option<()>, /// Verifying key placeholder _verifying_key: Option<()>, + _channel_marker: PhantomData, + _backend_marker: PhantomData, + _merkle_channel_marker: PhantomData, } -impl StwoProver { - #[allow(dead_code)] - #[allow(unused_variables)] +impl<'a, F: FieldElement, B, MC, C> StwoProver +where + B: Backend + Send + BackendForChannel, // Ensure B implements BackendForChannel + MC: MerkleChannel + Send, + C: Channel + Send, + MC::H: DeserializeOwned + Serialize, + PowdrComponent<'a, F>: ComponentProver, +{ pub fn new( - _analyzed: Arc>, + analyzed: Arc>, _fixed: Arc)>>, - setup: Option<&mut dyn io::Read>, ) -> Result { Ok(Self { - _analyzed, + analyzed, _fixed, _proving_key: None, _verifying_key: None, + _channel_marker: PhantomData, + _backend_marker: PhantomData, + _merkle_channel_marker: PhantomData, }) } + pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { + let config = get_config(); + // twiddles are used for FFT, they are computed in a bigger group than the eval domain. + // the eval domain is the half coset G_{2n} + + // twiddles are computed in half coset G_{4n} + , double the size of eval doamin. + let twiddles = B::precompute_twiddles( + CanonicCoset::new(self.analyzed.degree().ilog2() + 1 + FRI_LOG_BLOWUP as u32) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let mut prover_channel = ::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::::new(config, &twiddles); + + let trace = gen_stwo_circuit_trace::(witness); + + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(&mut prover_channel); + + let component = PowdrComponent::new( + &mut TraceLocationAllocator::default(), + PowdrEval::new(self.analyzed.clone()), + ); + + let proof = stwo_prover::core::prover::prove::( + &[&component], + &mut prover_channel, + commitment_scheme, + ) + .unwrap(); + + Ok(bincode::serialize(&proof).unwrap()) + } + + pub fn verify(&self, proof: &[u8], _instances: &[F]) -> Result<(), String> { + assert!( + _instances.is_empty(), + "Expected _instances slice to be empty, but it has {} elements.", + _instances.len() + ); + + let config = get_config(); + let proof: StarkProof = + bincode::deserialize(proof).map_err(|e| format!("Failed to deserialize proof: {e}"))?; + + let mut verifier_channel = ::C::default(); + let mut commitment_scheme = CommitmentSchemeVerifier::::new(config); + + //Constraints that are to be proved + let component = PowdrComponent::new( + &mut TraceLocationAllocator::default(), + PowdrEval::new(self.analyzed.clone()), + ); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + // TODO: When constant columns are supported, there will be more than one sizes and proof.commitments + // size[0] is for constant columns, size[1] is for witness columns, size[2] is for lookup columns + // pass size[1] for witness columns now is not doable due to this branch is outdated for the new feature of constant columns + // it will throw errors. + let sizes = component.trace_log_degree_bounds(); + assert_eq!(sizes.len(), 1); + commitment_scheme.commit(proof.commitments[0], &sizes[0], &mut verifier_channel); + + stwo_prover::core::prover::verify( + &[&component], + &mut verifier_channel, + &mut commitment_scheme, + proof, + ) + .map_err(|e| e.to_string()) + } +} + +fn get_config() -> PcsConfig { + PcsConfig { + pow_bits: FRI_PROOF_OF_WORK_BITS as u32, + fri_config: FriConfig::new( + LOG_LAST_LAYER_DEGREE_BOUND as u32, + FRI_LOG_BLOWUP as u32, + FRI_NUM_QUERIES, + ), + } } diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 9dc0a86e63..63b0c4a13c 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -541,3 +541,29 @@ pub fn run_reparse_test_with_blacklist(file: &str, blacklist: &[&str]) { .compute_optimized_pil() .unwrap(); } + +#[cfg(feature = "stwo")] +use powdr_number::Mersenne31Field; +#[cfg(feature = "stwo")] +pub fn test_stwo(file_name: &str, inputs: Vec) { + let backend = powdr_backend::BackendType::Stwo; + + let mut pipeline = Pipeline::default() + .with_tmp_output() + .from_file(resolve_test_file(file_name)) + .with_prover_inputs(inputs) + .with_backend(backend, None); + + let proof = pipeline.compute_proof().cloned().unwrap(); + let publics: Vec = pipeline + .publics() + .clone() + .unwrap() + .iter() + .map(|(_name, v)| v.expect("all publics should be known since we created a proof")) + .collect(); + pipeline.verify(&proof, &[publics]).unwrap(); +} + +#[cfg(not(feature = "stwo"))] +pub fn test_stwo(_file_name: &str, _inputs: Vec) {} diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index f7ddfc7008..b4385614d9 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -8,7 +8,7 @@ use powdr_pipeline::{ assert_proofs_fail_for_invalid_witnesses_pilcom, gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, make_simple_prepared_pipeline, regular_test, run_pilcom_with_backend_variant, test_halo2, - test_halo2_with_backend_variant, test_pilcom, test_plonky3_with_backend_variant, + test_halo2_with_backend_variant, test_pilcom, test_plonky3_with_backend_variant, test_stwo, BackendVariant, }, Pipeline, @@ -264,6 +264,11 @@ fn add() { ); } +#[test] +fn stwo_add_and_equal() { + let f = "pil/add_and_equal.pil"; + test_stwo(f, Default::default()); +} #[test] fn simple_div() { let f = "pil/simple_div.pil"; diff --git a/test_data/pil/add_and_equal.pil b/test_data/pil/add_and_equal.pil new file mode 100644 index 0000000000..e765218354 --- /dev/null +++ b/test_data/pil/add_and_equal.pil @@ -0,0 +1,10 @@ +namespace Add(4); + + let a; + std::prelude::set_hint(a, |i| if i==0||i==2 {std::prelude::Query::Hint(2147483646)}else{std::prelude::Query::Hint(1)}); + let b; + let c; + b = c; + a+a'=0; + + \ No newline at end of file