Skip to content

Commit

Permalink
use Borrow to avoid vec allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-a16z committed Jan 24, 2025
1 parent 0b09b13 commit 6bc9210
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 72 deletions.
6 changes: 1 addition & 5 deletions jolt-core/benches/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ fn benchmark_commit<PCS, F, ProofTranscript>(
&format!("{} Commit(mode:{:?}): {}% Ones", name, mode, threshold),
|b| {
b.iter(|| {
PCS::batch_commit(
&leaves.iter().collect::<Vec<_>>(),
&setup,
batch_type.clone(),
);
PCS::batch_commit(&leaves, &setup, batch_type.clone());
});
},
);
Expand Down
6 changes: 1 addition & 5 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,7 @@ impl<F: JoltField> JoltPolynomials<F> {
|| PCS::commit(&self.read_write_memory.t_final, &preprocessing.generators)
);
commitments.instruction_lookups.final_cts = PCS::batch_commit(
&self
.instruction_lookups
.final_cts
.iter()
.collect::<Vec<_>>(),
&self.instruction_lookups.final_cts,
&preprocessing.generators,
BatchType::Big,
);
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/lasso/surge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ where
.zip(trace_comitments.into_iter())
.for_each(|(dest, src)| *dest = src);
commitments.final_cts = PCS::batch_commit(
&polynomials.final_cts.iter().collect::<Vec<_>>(),
&polynomials.final_cts,
generators,
BatchType::SurgeInitFinal,
);
Expand Down
60 changes: 31 additions & 29 deletions jolt-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use ark_std::vec::Vec;
use icicle_core::curve::Affine;
use num_integer::Integer;
use rayon::prelude::*;
use std::borrow::Borrow;

pub(crate) mod icicle;
use crate::field::JoltField;
Expand Down Expand Up @@ -218,12 +219,15 @@ where
}

#[tracing::instrument(skip_all)]
fn batch_msm(
fn batch_msm<P>(
bases: &[Self::MulBase],
gpu_bases: Option<&[GpuBaseType<Self>]>,
polys: &[&MultilinearPolynomial<Self::ScalarField>],
) -> Vec<Self> {
assert!(polys.par_iter().all(|s| s.len() == bases.len()));
polys: &[P],
) -> Vec<Self>
where
P: Borrow<MultilinearPolynomial<Self::ScalarField>> + Sync,
{
assert!(polys.par_iter().all(|s| s.borrow().len() == bases.len()));
#[cfg(not(feature = "icicle"))]
assert!(gpu_bases.is_none());
assert_eq!(bases.len(), gpu_bases.map_or(bases.len(), |b| b.len()));
Expand All @@ -235,7 +239,7 @@ where
let _guard = span.enter();
return polys
.into_par_iter()
.map(|poly| Self::msm(bases, None, poly, None).unwrap())
.map(|poly| Self::msm(bases, None, poly.borrow(), None).unwrap())
.collect();
}

Expand All @@ -246,15 +250,16 @@ where
polys
.par_iter()
.enumerate()
.partition_map(|(i, poly)| match poly {
.partition_map(|(i, poly)| match poly.borrow() {
MultilinearPolynomial::LargeScalars(_) => {
let max_num_bits = poly.max_num_bits();
// Use GPU for large-scalar polynomials
Either::Right((i, max_num_bits, *poly))
let max_num_bits = poly.borrow().max_num_bits();
let poly: &DensePolynomial<Self::ScalarField> =
poly.borrow().try_into().unwrap();
Either::Right((i, max_num_bits, poly.evals_ref()))
}
_ => {
let max_num_bits = poly.max_num_bits();
Either::Left((i, max_num_bits, *poly))
let max_num_bits = poly.borrow().max_num_bits();
Either::Left((i, max_num_bits, poly.borrow()))
}
});
drop(_guard);
Expand Down Expand Up @@ -305,16 +310,8 @@ where
.unzip();

let max_num_bits = max_num_bits.iter().max().unwrap();
let scalars: Vec<_> = chunk_polys
.into_iter()
.map(|poly| {
let poly: &DensePolynomial<Self::ScalarField> =
poly.try_into().unwrap();
poly.evals_ref()
})
.collect();
let batch_results =
icicle_batch_msm(gpu_bases, &scalars, *max_num_bits as usize);
icicle_batch_msm(gpu_bases, &chunk_polys, *max_num_bits as usize);

// Store GPU results using original indices
for ((original_idx, _, _), result) in work_chunk.iter().zip(batch_results) {
Expand All @@ -330,14 +327,18 @@ where
results
}

// implement variable msm batch based on the function above
// a "batch" msm that can handle scalars of different sizes
// it mostly amortizes copy costs of sending the generators to the GPU
#[tracing::instrument(skip_all)]
fn variable_batch_msm(
fn variable_batch_msm<P>(
bases: &[Self::MulBase],
gpu_bases: Option<&[GpuBaseType<Self>]>,
polys: &[&MultilinearPolynomial<Self::ScalarField>],
) -> Vec<Self> {
assert!(polys.par_iter().all(|s| s.len() >= bases.len()));
polys: &[P],
) -> Vec<Self>
where
P: Borrow<MultilinearPolynomial<Self::ScalarField>> + Sync,
{
assert!(polys.par_iter().all(|s| s.borrow().len() >= bases.len()));
#[cfg(not(feature = "icicle"))]
assert!(gpu_bases.is_none());

Expand All @@ -348,7 +349,7 @@ where
let _guard = span.enter();
return polys
.into_par_iter()
.map(|poly| Self::msm(bases, None, poly, None).unwrap())
.map(|poly| Self::msm(bases, None, poly.borrow(), None).unwrap())
.collect();
}

Expand All @@ -357,12 +358,13 @@ where
let _guard = span.enter();
let (cpu_batch, gpu_batch): (Vec<_>, Vec<_>) =
polys.par_iter().enumerate().partition_map(|(i, poly)| {
let max_num_bits = poly.max_num_bits();
let max_num_bits = poly.borrow().max_num_bits();
if use_icicle && max_num_bits > 10 {
let poly: &DensePolynomial<Self::ScalarField> = (*poly).try_into().unwrap();
let poly: &DensePolynomial<Self::ScalarField> =
poly.borrow().try_into().unwrap();
Either::Right((i, max_num_bits, poly.evals_ref()))
} else {
Either::Left((i, max_num_bits, *poly))
Either::Left((i, max_num_bits, poly.borrow()))
}
});
drop(_guard);
Expand Down
10 changes: 7 additions & 3 deletions jolt-core/src/poly/commitment/binius.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::poly::multilinear_polynomial::MultilinearPolynomial;
use crate::utils::errors::ProofVerifyError;
use crate::utils::transcript::{AppendToTranscript, Transcript};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use std::borrow::Borrow;
use std::marker::PhantomData;

#[derive(Clone)]
Expand Down Expand Up @@ -50,11 +51,14 @@ impl<ProofTranscript: Transcript> CommitmentScheme<ProofTranscript>
) -> Self::Commitment {
todo!()
}
fn batch_commit(
_polys: &[&MultilinearPolynomial<Self::Field>],
fn batch_commit<P>(
_polys: &[P],
_gens: &Self::Setup,
_batch_type: BatchType,
) -> Vec<Self::Commitment> {
) -> Vec<Self::Commitment>
where
P: Borrow<MultilinearPolynomial<Self::Field>>,
{
todo!()
}
fn prove(
Expand Down
9 changes: 6 additions & 3 deletions jolt-core/src/poly/commitment/commitment_scheme.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use std::borrow::Borrow;
use std::fmt::Debug;

use crate::utils::transcript::Transcript;
Expand Down Expand Up @@ -48,11 +49,13 @@ pub trait CommitmentScheme<ProofTranscript: Transcript>: Clone + Sync + Send + '

fn setup(shapes: &[CommitShape]) -> Self::Setup;
fn commit(poly: &MultilinearPolynomial<Self::Field>, setup: &Self::Setup) -> Self::Commitment;
fn batch_commit(
polys: &[&MultilinearPolynomial<Self::Field>],
fn batch_commit<U>(
polys: &[U],
gens: &Self::Setup,
batch_type: BatchType,
) -> Vec<Self::Commitment>;
) -> Vec<Self::Commitment>
where
U: Borrow<MultilinearPolynomial<Self::Field>> + Sync;

/// Homomorphically combines multiple commitments into a single commitment, computed as a
/// linear combination with the given coefficients.
Expand Down
16 changes: 8 additions & 8 deletions jolt-core/src/poly/commitment/hyperkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use std::borrow::Borrow;
use std::{marker::PhantomData, sync::Arc};

pub struct HyperKZGSRS<P: Pairing>(Arc<SRS<P>>)
Expand Down Expand Up @@ -323,11 +324,7 @@ where
assert_eq!(polys[ell - 1].len(), 2);

// We do not need to commit to the first polynomial as it is already committed.
// Compute commitments in parallel
// TODO: This could be done by batch too if it gets progressively smaller.
let com: Vec<P::G1Affine> = into_optimal_iter!(1..polys.len())
.map(|i| UnivariateKZG::commit_as_univariate(&pk.kzg_pk, &polys[i]).unwrap())
.collect();
let com: Vec<P::G1Affine> = UnivariateKZG::commit_variable_batch(&pk.kzg_pk, &polys[1..])?;

// Phase 2
// We do not need to add x to the transcript, because in our context x was obtained from the transcript.
Expand Down Expand Up @@ -439,11 +436,14 @@ where
}

#[tracing::instrument(skip_all, name = "HyperKZG::batch_commit")]
fn batch_commit(
polys: &[&MultilinearPolynomial<Self::Field>],
fn batch_commit<U>(
polys: &[U],
gens: &Self::Setup,
_batch_type: BatchType,
) -> Vec<Self::Commitment> {
) -> Vec<Self::Commitment>
where
U: Borrow<MultilinearPolynomial<Self::Field>> + Sync,
{
UnivariateKZG::commit_batch(&gens.0.kzg_pk, polys)
.unwrap()
.into_par_iter()
Expand Down
29 changes: 19 additions & 10 deletions jolt-core/src/poly/commitment/kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use ark_ff::PrimeField;
use ark_std::{One, UniformRand, Zero};
use rand_core::{CryptoRng, RngCore};
use rayon::prelude::*;
use std::borrow::Borrow;
use std::marker::PhantomData;
use std::sync::Arc;

Expand Down Expand Up @@ -194,25 +195,33 @@ where
P::G1: Icicle,
{
#[tracing::instrument(skip_all, name = "KZG::commit_batch")]
pub fn commit_batch(
pub fn commit_batch<U>(
pk: &KZGProverKey<P>,
polys: &[&MultilinearPolynomial<P::ScalarField>],
) -> Result<Vec<P::G1Affine>, ProofVerifyError> {
polys: &[U],
) -> Result<Vec<P::G1Affine>, ProofVerifyError>
where
U: Borrow<MultilinearPolynomial<P::ScalarField>> + Sync,
{
let g1_powers = &pk.g1_powers();
let gpu_g1 = pk.gpu_g1();

// batch commit requires all batches to have the same length
assert!(polys.par_iter().all(|s| s.len() == polys[0].len()));
assert!(polys[0].len() <= g1_powers.len());

if let Some(invalid) = polys.iter().find(|coeffs| coeffs.len() > g1_powers.len()) {
assert!(polys
.par_iter()
.all(|s| s.borrow().len() == polys[0].borrow().len()));
assert!(polys[0].borrow().len() <= g1_powers.len());

if let Some(invalid) = polys
.iter()
.find(|coeffs| (*coeffs).borrow().len() > g1_powers.len())
{
return Err(ProofVerifyError::KeyLengthError(
g1_powers.len(),
invalid.len(),
invalid.borrow().len(),
));
}

let msm_size = polys[0].len();
let msm_size = polys[0].borrow().len();
let commitments = <P::G1 as VariableBaseMSM>::batch_msm(
&g1_powers[..msm_size],
gpu_g1.map(|g| &g[..msm_size]),
Expand All @@ -225,7 +234,7 @@ where
#[tracing::instrument(skip_all, name = "KZG::commit_variable_batch_with_mode")]
pub fn commit_variable_batch(
pk: &KZGProverKey<P>,
polys: &[&MultilinearPolynomial<P::ScalarField>],
polys: &[MultilinearPolynomial<P::ScalarField>],
) -> Result<Vec<P::G1Affine>, ProofVerifyError> {
let g1_powers = &pk.g1_powers();
let gpu_g1 = pk.gpu_g1();
Expand Down
10 changes: 7 additions & 3 deletions jolt-core/src/poly/commitment/mock.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Borrow;
use std::marker::PhantomData;

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand Down Expand Up @@ -49,11 +50,14 @@ where
fn commit(poly: &MultilinearPolynomial<Self::Field>, _setup: &Self::Setup) -> Self::Commitment {
MockCommitment { poly: poly.clone() }
}
fn batch_commit(
polys: &[&MultilinearPolynomial<Self::Field>],
fn batch_commit<P>(
polys: &[P],
setup: &Self::Setup,
_batch_type: BatchType,
) -> Vec<Self::Commitment> {
) -> Vec<Self::Commitment>
where
P: Borrow<MultilinearPolynomial<Self::Field>>,
{
polys
.into_iter()
.map(|poly| Self::commit(poly, setup))

Check failure on line 63 in jolt-core/src/poly/commitment/mock.rs

View workflow job for this annotation

GitHub Actions / test

mismatched types
Expand Down
13 changes: 8 additions & 5 deletions jolt-core/src/poly/commitment/zeromorph.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]

use std::{iter, marker::PhantomData};

use crate::msm::{use_icicle, Icicle, VariableBaseMSM};
use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation};
use crate::poly::{dense_mlpoly::DensePolynomial, unipoly::UniPoly};
Expand All @@ -17,7 +15,9 @@ use ark_std::{One, Zero};
use itertools::izip;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
use rand_core::{CryptoRng, RngCore};
use std::borrow::Borrow;
use std::sync::Arc;
use std::{iter, marker::PhantomData};

use super::{
commitment_scheme::{BatchType, CommitShape, CommitmentScheme},
Expand Down Expand Up @@ -456,11 +456,14 @@ where
ZeromorphCommitment(UnivariateKZG::commit_as_univariate(&setup.0.commit_pp, poly).unwrap())
}

fn batch_commit(
polys: &[&MultilinearPolynomial<Self::Field>],
fn batch_commit<U>(
polys: &[U],
gens: &Self::Setup,
_batch_type: BatchType,
) -> Vec<Self::Commitment> {
) -> Vec<Self::Commitment>
where
U: Borrow<MultilinearPolynomial<Self::Field>> + Sync,
{
UnivariateKZG::commit_batch(&gens.0.commit_pp, polys)
.unwrap()
.into_iter()
Expand Down

0 comments on commit 6bc9210

Please sign in to comment.