Skip to content

Commit

Permalink
plonk: benches: collaborative-proof: Add benchmark with stats collection
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Dec 4, 2023
1 parent dee7e22 commit 80ac26f
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 107 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ repository = "https://github.com/EspressoSystems/jellyfish"
[workspace.dependencies]
itertools = { version = "0.10.1", default-features = false }
ark-mpc = { git = "https://github.com/renegade-fi/ark-mpc" }

[profile.bench]
debug = true
12 changes: 10 additions & 2 deletions plonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ license = { workspace = true }
rust-version = { workspace = true }

[dependencies]
ark-bn254 = { version = "0.4.0", optional = true }
ark-ec = "0.4.0"
ark-ff = { version = "0.4.0", features = ["asm"] }
ark-mpc = { workspace = true }
Expand All @@ -30,6 +31,7 @@ mpc-relation = { path = "../relation", default-features = false }
jf-utils = { path = "../utilities" }
merlin = { version = "3.0.0", default-features = false }
num-bigint = { version = "0.4", default-features = false }
rand = { version = "0.8", optional = true }
rand_chacha = { version = "0.3.1", default-features = false }
rayon = { version = "1.5.0", optional = true }
serde = { version = "1.0", default-features = false, features = ["derive"] }
Expand All @@ -44,19 +46,25 @@ ark-bw6-761 = "0.4.0"
ark-ed-on-bls12-377 = "0.4.0"
ark-ed-on-bls12-381 = "0.4.0"
ark-ed-on-bn254 = "0.4.0"
criterion = { version = "0.5", features = ["async", "async_tokio"] }
hex = "^0.4.3"
tokio = "1.33"
rand = "0.8"

# Benchmarks
[[bench]]
name = "plonk-benches"
path = "benches/bench.rs"
harness = false

[[bench]]
name = "collaborative_proof"
harness = false
required-features = ["test_apis", "test-srs"]

[features]
all-tests = ["ark-mpc/test_helpers", "jf-primitives/test-srs"]
default = ["parallel", "std"]
stats = ["ark-mpc/stats"]
std = [
"ark-std/std",
"ark-serialize/std",
Expand All @@ -72,7 +80,7 @@ std = [
"rand_chacha/std",
"sha3/std",
]
test_apis = [] # exposing apis for testing purpose
test_apis = ["ark-bn254", "rand", "ark-mpc/test_helpers"]
parallel = [
"ark-ff/parallel",
"ark-ec/parallel",
Expand Down
149 changes: 149 additions & 0 deletions plonk/benches/collaborative_proof.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//! Benchmarks a collaborative proof and a singleprover proof on the same
//! circuit for baselining
use std::time::{Duration, Instant};

use ark_mpc::test_helpers::execute_mock_mpc;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use mpc_plonk::{
multiprover::proof_system::{
test_helpers::{gen_multiprover_circuit_for_test, setup_snark, TestCurve},
MultiproverPlonkKzgSnark,
},
proof_system::{snark_test_helpers::gen_circuit_for_test, structs::ProvingKey, PlonkKzgSnark},
transcript::SolidityTranscript,
};
use mpc_relation::PlonkType;
use rand::thread_rng;
use tokio::runtime::Builder as RuntimeBuilder;

const CIRCUIT_SIZING_PARAM: usize = 100;

// -----------
// | Helpers |
// -----------

/// Setup a proving key for the benchmark circuit
fn setup_pk() -> ProvingKey<TestCurve> {
// Build a circuit and setup the proving key
let circuit = gen_circuit_for_test(
CIRCUIT_SIZING_PARAM,
0, // start_val
PlonkType::TurboPlonk,
)
.unwrap();
let (pk, _) = setup_snark(&circuit);

pk
}

// --------------
// | Benchmarks |
// --------------

/// Benchmark a singleprover proof on a test circuit
fn bench_singleprover(c: &mut Criterion) {
// Build a circuit to prove satisfaction for
let pk = setup_pk();
let circuit = gen_circuit_for_test(
CIRCUIT_SIZING_PARAM,
0, // start_val
PlonkType::TurboPlonk,
)
.unwrap();

let mut group = c.benchmark_group("singleprover");
let id = BenchmarkId::new("prover-latency", CIRCUIT_SIZING_PARAM);
group.bench_function(id, |b| {
b.iter(|| {
let mut rng = thread_rng();
let res = PlonkKzgSnark::batch_prove::<_, _, SolidityTranscript>(
&mut rng,
&[&circuit],
&[&pk],
)
.unwrap();
black_box(res);
})
});
}

/// Benchmark a collaborative proof on a test circuit
fn bench_multiprover(c: &mut Criterion) {
let runtime = RuntimeBuilder::new_multi_thread()
.worker_threads(3)
.enable_all()
.build()
.unwrap();

let pk = setup_pk();

let mut group = c.benchmark_group("multiprover");
let id = BenchmarkId::new("prover-latency", CIRCUIT_SIZING_PARAM);
group.bench_function(id, |b| {
let mut b = b.to_async(&runtime);
b.iter_custom(|n_iters| {
let pk = pk.clone();
async move {
let mut total_time = Duration::from_millis(0);
for _ in 0..n_iters {
let elapsed = multiprover_prove(&pk).await;
total_time += elapsed;
}

total_time
}
})
});
}

/// Prove the test circuit in a multiprover setting using the given key
///
/// Return the latency excluding the MPC setup time
async fn multiprover_prove(pk: &ProvingKey<TestCurve>) -> Duration {
let (elapsed1, elapsed2) = execute_mock_mpc(|fabric| {
let pk = pk.clone();
async move {
let start = Instant::now();
let circuit = gen_multiprover_circuit_for_test(
CIRCUIT_SIZING_PARAM,
0, // start val
fabric.clone(),
)
.unwrap();

black_box(
MultiproverPlonkKzgSnark::prove(&circuit, &pk, fabric)
.unwrap()
.open_authenticated()
.await
.unwrap(),
);

start.elapsed()
}
})
.await;

Duration::max(elapsed1, elapsed2)
}

criterion_group! {
name = prover_latency;
config = Criterion::default().sample_size(10);
targets = bench_singleprover, bench_multiprover,
}

#[cfg(not(feature = "stats"))]
criterion_main!(prover_latency);

#[cfg(feature = "stats")]
#[tokio::main]
async fn main() {
let pk = setup_pk();
let duration = multiprover_prove(&pk).await;

// Let the fabric print its stats
tokio::time::sleep(Duration::from_secs(1)).await;
println!("\nTook: {duration:?}");
}
32 changes: 5 additions & 27 deletions plonk/src/multiprover/proof_system/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,6 @@ pub fn mul_poly_result<C: CurveGroup>(

#[cfg(test)]
pub(crate) mod test {

use ark_ec::pairing::Pairing;
use ark_ff::{One, Zero};
use ark_mpc::{
algebra::{AuthenticatedDensePoly, Scalar},
Expand All @@ -825,38 +823,18 @@ pub(crate) mod test {
use rand::thread_rng;

use crate::{
multiprover::proof_system::{MpcChallenges, MpcOracles, MpcPlonkCircuit},
multiprover::proof_system::{
test_helpers::{setup_snark, TestCurve, TestGroup, TestScalar},
MpcChallenges, MpcOracles, MpcPlonkCircuit,
},
proof_system::{
prover::Prover,
structs::{Challenges, Oracles, ProofEvaluations, ProvingKey, VerifyingKey},
PlonkKzgSnark, UniversalSNARK,
structs::{Challenges, Oracles, ProofEvaluations, ProvingKey},
},
};

use super::MpcProver;

/// The curve used for testing
pub type TestCurve = ark_bn254::Bn254;
/// The curve group used for testing
pub type TestGroup = <TestCurve as Pairing>::G1;
/// The scalar field of the test curve
pub type TestScalar = <TestCurve as Pairing>::ScalarField;

/// The max degree of the circuits used for testing
pub const MAX_DEGREE_TESTING: usize = 1024;

/// Setup commitment keys, proving and verification keys for the snark
pub(crate) fn setup_snark<C: Arithmetization<TestScalar>>(
circuit: &C,
) -> (ProvingKey<TestCurve>, VerifyingKey<TestCurve>) {
let mut rng = thread_rng();
let srs =
PlonkKzgSnark::<TestCurve>::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng)
.unwrap();

PlonkKzgSnark::<TestCurve>::preprocess(&srs, circuit).unwrap()
}

/// Get a randomized set of challenges
fn randomized_challenges() -> Challenges<TestScalar> {
let mut rng = thread_rng();
Expand Down
76 changes: 55 additions & 21 deletions plonk/src/multiprover/proof_system/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,37 +192,48 @@ impl<P: SWCurveConfig<BaseField = E::BaseField>, E: Pairing<G1Affine = Affine<P>
}
}

#[cfg(test)]
mod tests {
use ark_mpc::{
algebra::{AuthenticatedScalarResult, Scalar},
test_helpers::execute_mock_mpc,
MpcFabric,
};
use futures::future::join_all;
use itertools::Itertools;
use mpc_relation::{traits::*, PlonkType};
use rand::{thread_rng, Rng};
#[cfg(any(test, feature = "test_apis"))]
pub mod test_helpers {
//! Test helpers for the multiprover snark
use ark_ec::pairing::Pairing;
use ark_mpc::{algebra::Scalar, MpcFabric};
use mpc_relation::traits::*;
use rand::thread_rng;

use crate::{
errors::PlonkError,
multiprover::proof_system::{
test::{setup_snark, test_multiprover_circuit, test_singleprover_circuit, TestGroup},
MpcPlonkCircuit,
multiprover::proof_system::MpcPlonkCircuit,
proof_system::{
structs::{ProvingKey, VerifyingKey},
PlonkKzgSnark, UniversalSNARK,
},
proof_system::{snark::test::gen_circuit_for_test, PlonkKzgSnark},
transcript::SolidityTranscript,
};

use super::MultiproverPlonkKzgSnark;
/// The curve used for testing
pub type TestCurve = ark_bn254::Bn254;
/// The curve group used for testing
pub type TestGroup = <TestCurve as Pairing>::G1;
/// The scalar field of the test curve
pub type TestScalar = <TestCurve as Pairing>::ScalarField;

// -----------
// | Helpers |
// -----------
/// The max degree of the circuits used for testing
pub const MAX_DEGREE_TESTING: usize = 1024;

/// Setup commitment keys, proving and verification keys for the snark
pub fn setup_snark<C: Arithmetization<TestScalar>>(
circuit: &C,
) -> (ProvingKey<TestCurve>, VerifyingKey<TestCurve>) {
let mut rng = thread_rng();
let srs =
PlonkKzgSnark::<TestCurve>::universal_setup_for_testing(MAX_DEGREE_TESTING, &mut rng)
.unwrap();

PlonkKzgSnark::<TestCurve>::preprocess(&srs, circuit).unwrap()
}

/// A multiprover analog of the circuit used for testing the single-prover
/// implementation in `plonk/proof_system/snark.rs`
pub(crate) fn gen_multiprover_circuit_for_test(
pub fn gen_multiprover_circuit_for_test(
m: usize,
a0: usize,
fabric: MpcFabric<TestGroup>,
Expand Down Expand Up @@ -268,6 +279,29 @@ mod tests {

Ok(cs)
}
}

#[cfg(test)]
mod tests {
use ark_mpc::{
algebra::{AuthenticatedScalarResult, Scalar},
test_helpers::execute_mock_mpc,
};
use futures::future::join_all;
use itertools::Itertools;
use mpc_relation::{traits::*, PlonkType};
use rand::{thread_rng, Rng};

use crate::{
multiprover::proof_system::{
test::{test_multiprover_circuit, test_singleprover_circuit},
test_helpers::{gen_multiprover_circuit_for_test, setup_snark},
},
proof_system::{snark::test_helpers::gen_circuit_for_test, PlonkKzgSnark},
transcript::SolidityTranscript,
};

use super::MultiproverPlonkKzgSnark;

// ---------
// | Tests |
Expand Down
3 changes: 3 additions & 0 deletions plonk/src/proof_system/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ pub mod verifier;
use crate::transcript::PlonkTranscript;
pub use snark::PlonkKzgSnark;

#[cfg(feature = "test_apis")]
pub use snark::test_helpers as snark_test_helpers;

// TODO: (alex) should we name it `PlonkishSNARK` instead? since we use
// `PlonkTranscript` on prove and verify.
/// An interface for SNARKs with universal setup.
Expand Down
Loading

0 comments on commit 80ac26f

Please sign in to comment.