diff --git a/Cargo.lock b/Cargo.lock index aefc84ab4..031c688dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1558,7 +1558,7 @@ dependencies = [ [[package]] name = "icicle-bn254" version = "3.2.0" -source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=ed93e21#ed93e21cbb405822b0aa1b58b5dc6c7837a04108" +source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=94fe8ca#94fe8cad314a716c6a76c526cab7683da798bc8d" dependencies = [ "cmake", "icicle-core", @@ -1569,7 +1569,7 @@ dependencies = [ [[package]] name = "icicle-core" version = "3.2.0" -source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=ed93e21#ed93e21cbb405822b0aa1b58b5dc6c7837a04108" +source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=94fe8ca#94fe8cad314a716c6a76c526cab7683da798bc8d" dependencies = [ "hex", "icicle-runtime", @@ -1581,7 +1581,7 @@ dependencies = [ [[package]] name = "icicle-hash" version = "3.2.0" -source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=ed93e21#ed93e21cbb405822b0aa1b58b5dc6c7837a04108" +source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=94fe8ca#94fe8cad314a716c6a76c526cab7683da798bc8d" dependencies = [ "cmake", "icicle-core", @@ -1592,7 +1592,7 @@ dependencies = [ [[package]] name = "icicle-runtime" version = "3.2.0" -source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=ed93e21#ed93e21cbb405822b0aa1b58b5dc6c7837a04108" +source = "git+https://github.com/ingonyama-zk/icicle-jolt.git?rev=94fe8ca#94fe8cad314a716c6a76c526cab7683da798bc8d" dependencies = [ "cmake", "once_cell", diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index 7df78f525..0d9a37915 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -124,16 +124,23 @@ name = "jolt_core" path = "src/lib.rs" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] +memory-stats = "1.0.0" +sys-info = "0.9.1" +tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] } + +[target.'cfg(all(not(target_arch = "wasm32"), target_os = "macos"))'.dependencies] +icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true } +icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true } +icicle-bn254 = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true } + +[target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "macos")))'.dependencies] +icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true } icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", features = [ "cuda_backend", -], rev = "ed93e21", optional = true } -icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "ed93e21", optional = true } +], rev = "94fe8ca", optional = true } icicle-bn254 = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", features = [ "cuda_backend", -], rev = "ed93e21", optional = true } -memory-stats = "1.0.0" -sys-info = "0.9.1" -tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] } +], rev = "94fe8ca", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/jolt-core/benches/commit.rs b/jolt-core/benches/commit.rs index 86a2e6a70..6e3a7b784 100644 --- a/jolt-core/benches/commit.rs +++ b/jolt-core/benches/commit.rs @@ -89,11 +89,7 @@ fn benchmark_commit( &format!("{} Commit(mode:{:?}): {}% Ones", name, mode, threshold), 
|b| { b.iter(|| { - PCS::batch_commit( - &leaves.iter().collect::>(), - &setup, - batch_type.clone(), - ); + PCS::batch_commit(&leaves, &setup, batch_type.clone()); }); }, ); diff --git a/jolt-core/benches/msm.rs b/jolt-core/benches/msm.rs index f314d7302..497c8e1f5 100644 --- a/jolt-core/benches/msm.rs +++ b/jolt-core/benches/msm.rs @@ -101,6 +101,9 @@ where } fn main() { + let small_value_lookup_tables = ::compute_lookup_tables(); + ::initialize_lookup_tables(small_value_lookup_tables); + let mut criterion = Criterion::default() .configure_from_args() .sample_size(20) diff --git a/jolt-core/benches/msm_batch.rs b/jolt-core/benches/msm_batch.rs index 3d94edc0f..19d4db4ff 100644 --- a/jolt-core/benches/msm_batch.rs +++ b/jolt-core/benches/msm_batch.rs @@ -60,34 +60,55 @@ fn random_poly(max_num_bits: usize, len: usize) -> MultilinearPolynomial { 0 => MultilinearPolynomial::from(vec![0u8; len]), 1..=8 => MultilinearPolynomial::from( (0..len) - .into_iter() - .map(|_| (rng.next_u32() & ((1 << max_num_bits) - 1)) as u8) + .map(|_| { + let mask = if max_num_bits == 8 { + u8::MAX + } else { + (1u8 << max_num_bits) - 1 + }; + (rng.next_u32() & (mask as u32)) as u8 + }) .collect::>(), ), 9..=16 => MultilinearPolynomial::from( (0..len) - .into_iter() - .map(|_| (rng.next_u32() & ((1 << max_num_bits) - 1)) as u16) + .map(|_| { + let mask = if max_num_bits == 16 { + u16::MAX + } else { + (1u16 << max_num_bits) - 1 + }; + (rng.next_u32() & (mask as u32)) as u16 + }) .collect::>(), ), 17..=32 => MultilinearPolynomial::from( (0..len) - .into_iter() - .map(|_| (rng.next_u64() & ((1 << max_num_bits) - 1)) as u32) + .map(|_| { + let mask = if max_num_bits == 32 { + u32::MAX + } else { + (1u32 << max_num_bits) - 1 + }; + (rng.next_u64() & (mask as u64)) as u32 + }) .collect::>(), ), 33..=64 => MultilinearPolynomial::from( (0..len) - .into_iter() - .map(|_| rng.next_u64() & ((1 << max_num_bits) - 1)) - .collect::>(), - ), - _ => MultilinearPolynomial::from( - (0..len) - .into_iter() - .map(|_| Fr::random(&mut rng)) + .map(|_| { + let mask = if max_num_bits == 64 { + u64::MAX + } else { + (1u64 << max_num_bits) - 1 + }; + rng.next_u64() & mask + }) .collect::>(), ), + _ => { + MultilinearPolynomial::from((0..len).map(|_| Fr::random(&mut rng)).collect::>()) + } } } @@ -120,6 +141,9 @@ fn benchmark_msm_batch( } fn main() { + let small_value_lookup_tables = ::compute_lookup_tables(); + ::initialize_lookup_tables(small_value_lookup_tables); + let mut criterion = Criterion::default() .configure_from_args() .sample_size(10) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index e4d5eb595..b4b9232f0 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -276,11 +276,7 @@ impl JoltPolynomials { || PCS::commit(&self.read_write_memory.t_final, &preprocessing.generators) ); commitments.instruction_lookups.final_cts = PCS::batch_commit( - &self - .instruction_lookups - .final_cts - .iter() - .collect::>(), + &self.instruction_lookups.final_cts, &preprocessing.generators, BatchType::Big, ); diff --git a/jolt-core/src/lasso/surge.rs b/jolt-core/src/lasso/surge.rs index cd0cd567f..36cd83c58 100644 --- a/jolt-core/src/lasso/surge.rs +++ b/jolt-core/src/lasso/surge.rs @@ -424,7 +424,7 @@ where .zip(trace_comitments.into_iter()) .for_each(|(dest, src)| *dest = src); commitments.final_cts = PCS::batch_commit( - &polynomials.final_cts.iter().collect::>(), + &polynomials.final_cts, generators, BatchType::SurgeInitFinal, ); diff --git a/jolt-core/src/msm/icicle/adapter.rs 
b/jolt-core/src/msm/icicle/adapter.rs index 6196b3bd6..e0898a13d 100644 --- a/jolt-core/src/msm/icicle/adapter.rs +++ b/jolt-core/src/msm/icicle/adapter.rs @@ -131,6 +131,97 @@ where V::to_ark_projective(&msm_host_result[0]) } +// batch_info is a tuple of (batch_id, bit_size, scalars) +pub type BatchInfo<'a, V: VariableBaseMSM> = (usize, usize, &'a [V::ScalarField]); + +#[tracing::instrument(skip_all, name = "icicle_variable_batch_msm")] +pub fn icicle_variable_batch_msm( + bases: &[GpuBaseType], + batch_info: &[BatchInfo], +) -> Vec<(usize, V)> +where + V: VariableBaseMSM, + V::ScalarField: JoltField, +{ + let mut stream = IcicleStream::create().unwrap(); + + let mut bases_slice = + DeviceVec::>::device_malloc_async(bases.len(), &stream).unwrap(); + let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu"); + let _guard = span.enter(); + bases_slice + .copy_from_host_async(HostSlice::from_slice(bases), &stream) + .unwrap(); + drop(_guard); + drop(span); + + let num_batches = batch_info.len(); + let mut msm_result = + DeviceVec::>::device_malloc_async(num_batches, &stream).unwrap(); + let mut msm_host_results = vec![Projective::::zero(); num_batches]; + + for (index, (_batch_id, bit_size, scalars)) in batch_info.iter().enumerate() { + let span = tracing::span!(tracing::Level::INFO, "convert_scalars"); + let _guard = span.enter(); + + let mut scalars_slice = + DeviceVec::<<::C as Curve>::ScalarField>::device_malloc_async( + scalars.len(), + &stream, + ) + .unwrap(); + let scalars_mont = unsafe { + &*(&scalars[..] as *const _ as *const [<::C as Curve>::ScalarField]) + }; + drop(_guard); + drop(span); + + let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu"); + let _guard = span.enter(); + scalars_slice + .copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream) + .unwrap(); + drop(_guard); + drop(span); + + let mut cfg = MSMConfig::default(); + cfg.stream_handle = IcicleStreamHandle::from(&stream); + cfg.is_async = true; + cfg.are_scalars_montgomery_form = true; + cfg.bitsize = *bit_size as i32; + + let span = tracing::span!(tracing::Level::INFO, "gpu_msm"); + let _guard = span.enter(); + + msm( + &scalars_slice, + &bases_slice[..scalars.len()], + &cfg, + &mut msm_result[index..index + 1], + ) + .unwrap(); + + drop(_guard); + drop(span); + } + + let span = tracing::span!(tracing::Level::INFO, "copy_msm_result"); + let _guard = span.enter(); + msm_result + .copy_to_host(HostSlice::from_mut_slice(&mut msm_host_results)) + .unwrap(); + drop(_guard); + drop(span); + + stream.synchronize().unwrap(); + stream.destroy().unwrap(); + batch_info + .par_iter() + .zip(msm_host_results) + .map(|((batch_id, _, _), result)| (*batch_id, V::to_ark_projective(&result))) + .collect() +} + /// Batch process msms - assumes batches are equal in size /// Variable Batch sizes is not currently supported by icicle #[tracing::instrument(skip_all)] diff --git a/jolt-core/src/msm/mod.rs b/jolt-core/src/msm/mod.rs index 21fa59182..2da5256f2 100644 --- a/jolt-core/src/msm/mod.rs +++ b/jolt-core/src/msm/mod.rs @@ -6,11 +6,10 @@ use ark_std::vec::Vec; use icicle_core::curve::Affine; use num_integer::Integer; use rayon::prelude::*; +use std::borrow::Borrow; pub(crate) mod icicle; use crate::field::JoltField; -#[cfg(feature = "icicle")] -use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::multilinear_polynomial::MultilinearPolynomial; use crate::utils::errors::ProofVerifyError; use crate::utils::math::Math; @@ -23,6 +22,7 @@ pub type GpuBaseType = Affine; #[cfg(not(feature = 
"icicle"))] pub type GpuBaseType = G::MulBase; +use crate::poly::unipoly::UniPoly; use itertools::Either; /// Copy of ark_ec::VariableBaseMSM with minor modifications to speed up @@ -219,24 +219,39 @@ where } #[tracing::instrument(skip_all)] - fn batch_msm( + fn batch_msm_common
<P>
( bases: &[Self::MulBase], gpu_bases: Option<&[GpuBaseType]>, - polys: &[&MultilinearPolynomial], - ) -> Vec { - assert!(polys.par_iter().all(|s| s.len() == bases.len())); + polys: &[P], + variable_batches: bool, + ) -> Vec + where + P: Borrow> + Sync, + { + // Validate input lengths + if variable_batches { + assert!(polys.par_iter().all(|s| s.borrow().len() <= bases.len())); + } else { + assert!(polys.par_iter().all(|s| s.borrow().len() == bases.len())); + assert_eq!(bases.len(), gpu_bases.map_or(bases.len(), |b| b.len())); + } + #[cfg(not(feature = "icicle"))] assert!(gpu_bases.is_none()); - assert_eq!(bases.len(), gpu_bases.map_or(bases.len(), |b| b.len())); let use_icicle = use_icicle(); + // Handle CPU-only case if !use_icicle { let span = tracing::span!(tracing::Level::INFO, "batch_msm_cpu_only"); let _guard = span.enter(); return polys .into_par_iter() - .map(|poly| Self::msm(bases, None, poly, None).unwrap()) + .map(|poly| { + let poly = poly.borrow(); + let bases_slice = &bases[..poly.len()]; + Self::msm(bases_slice, None, poly, None).unwrap() + }) .collect(); } @@ -244,20 +259,35 @@ where let span = tracing::span!(tracing::Level::INFO, "group_scalar_indices_parallel"); let _guard = span.enter(); let (cpu_batch, gpu_batch): (Vec<_>, Vec<_>) = - polys - .par_iter() - .enumerate() - .partition_map(|(i, poly)| match poly { - MultilinearPolynomial::LargeScalars(_) => { - let max_num_bits = poly.max_num_bits(); - // Use GPU for large-scalar polynomials - Either::Right((i, max_num_bits, *poly)) - } - _ => { - let max_num_bits = poly.max_num_bits(); - Either::Left((i, max_num_bits, *poly)) + polys.par_iter().enumerate().partition_map(|(i, poly)| { + let poly = poly.borrow(); + let max_num_bits = poly.max_num_bits(); + + if max_num_bits > 10 { + match poly { + MultilinearPolynomial::LargeScalars(poly) => { + Either::Right((i, max_num_bits, poly.evals())) + } + MultilinearPolynomial::U16Scalars(poly) => { + Either::Right((i, max_num_bits, poly.coeffs_as_field_elements())) + } + MultilinearPolynomial::U32Scalars(poly) => { + Either::Right((i, max_num_bits, poly.coeffs_as_field_elements())) + } + MultilinearPolynomial::U64Scalars(poly) => { + Either::Right((i, max_num_bits, poly.coeffs_as_field_elements())) + } + MultilinearPolynomial::I64Scalars(poly) => { + Either::Right((i, max_num_bits, poly.coeffs_as_field_elements())) + } + MultilinearPolynomial::U8Scalars(_) => unreachable!( + "MultilinearPolynomial::U8Scalars cannot have more than 10 bits" + ), } - }); + } else { + Either::Left((i, max_num_bits, poly)) + } + }); drop(_guard); drop(span); let mut results = vec![Self::zero(); polys.len()]; @@ -268,9 +298,10 @@ where let cpu_results: Vec<(usize, Self)> = cpu_batch .into_par_iter() .map(|(i, max_num_bits, poly)| { + let bases_slice = &bases[..poly.len()]; ( i, - Self::msm(bases, None, poly, Some(max_num_bits as usize)).unwrap(), + Self::msm(bases_slice, None, poly, Some(max_num_bits)).unwrap(), ) }) .collect(); @@ -294,33 +325,165 @@ where &backup }); - // includes putting the scalars and bases on device - let slice_bit_size = 256 * gpu_batch[0].2.len() * 2; - let slices_at_a_time = total_memory_bits() / slice_bit_size; - - // Process GPU batches with memory constraints - for work_chunk in gpu_batch.chunks(slices_at_a_time) { - let (max_num_bits, chunk_polys): (Vec<_>, Vec<_>) = work_chunk - .par_iter() - .map(|(_, max_num_bits, poly)| (*max_num_bits, *poly)) - .unzip(); - - let max_num_bits = max_num_bits.iter().max().unwrap(); - let scalars: Vec<_> = chunk_polys - .into_iter() - 
.map(|poly| { - let poly: &DensePolynomial = - poly.try_into().unwrap(); - poly.evals_ref() - }) - .collect(); - let batch_results = - icicle_batch_msm(gpu_bases, &scalars, *max_num_bits as usize); - - // Store GPU results using original indices - for ((original_idx, _, _), result) in work_chunk.iter().zip(batch_results) { - results[*original_idx] = result; + if variable_batches { + // Variable-length batch processing + let batch = gpu_batch + .iter() + .map(|(i, max_num_bits, poly)| (*i, *max_num_bits, poly.as_slice())) + .collect::>(); + let batched_results = icicle_variable_batch_msm(gpu_bases, &batch); + for (index, result) in batched_results { + results[index] = result; } + } else { + // Fixed-length batch processing + let slice_bit_size = 256 * gpu_batch[0].2.len() * 2; + let slices_at_a_time = total_memory_bits() / slice_bit_size; + + for work_chunk in gpu_batch.chunks(slices_at_a_time) { + let (max_num_bits, chunk_polys): (Vec<_>, Vec<_>) = work_chunk + .par_iter() + .map(|(_, max_num_bits, poly)| (*max_num_bits, poly.as_slice())) + .unzip(); + + let max_num_bits = max_num_bits.iter().max().unwrap(); + let batch_results = + icicle_batch_msm(gpu_bases, &chunk_polys, *max_num_bits); + + for ((index, _, _), result) in work_chunk.iter().zip(batch_results) { + results[*index] = result; + } + } + } + } + #[cfg(not(feature = "icicle"))] + { + unreachable!("icicle_init must not return true without the icicle feature"); + } + } + results + } + + #[tracing::instrument(skip_all)] + fn batch_msm
<P>
( + bases: &[Self::MulBase], + gpu_bases: Option<&[GpuBaseType]>, + polys: &[P], + ) -> Vec + where + P: Borrow> + Sync, + { + Self::batch_msm_common(bases, gpu_bases, polys, false) + } + + // a "batch" msm that can handle scalars of different sizes + // it mostly amortizes copy costs of sending the generators to the GPU + #[tracing::instrument(skip_all)] + fn variable_batch_msm
<P>
( + bases: &[Self::MulBase], + gpu_bases: Option<&[GpuBaseType]>, + polys: &[P], + ) -> Vec + where + P: Borrow> + Sync, + { + Self::batch_msm_common(bases, gpu_bases, polys, true) + } + + #[tracing::instrument(skip_all)] + fn variable_batch_msm_univariate
<P>
( + bases: &[Self::MulBase], + gpu_bases: Option<&[GpuBaseType]>, + polys: &[P], + ) -> Vec + where + P: Borrow> + Sync, + { + assert!(polys + .par_iter() + .all(|s| s.borrow().coeffs.len() <= bases.len())); + #[cfg(not(feature = "icicle"))] + assert!(gpu_bases.is_none()); + + let use_icicle = use_icicle(); + + if !use_icicle { + let span = tracing::span!(tracing::Level::INFO, "batch_msm_cpu_only"); + let _guard = span.enter(); + return polys + .into_par_iter() + .map(|poly| { + Self::msm_field_elements( + &bases[..poly.borrow().coeffs.len()], + None, + &poly.borrow().coeffs, + None, + false, + ) + .unwrap() + }) + .collect(); + } + + // Split scalar batches into CPU and GPU workloads + let span = tracing::span!(tracing::Level::INFO, "group_scalar_indices_parallel"); + let _guard = span.enter(); + let (cpu_batch, gpu_batch): (Vec<_>, Vec<_>) = + polys.par_iter().enumerate().partition_map(|(i, poly)| { + let poly = poly.borrow(); + let max_num_bits = (*poly.coeffs.iter().max().unwrap()).num_bits() as usize; + if use_icicle && max_num_bits > 10 { + Either::Right((i, max_num_bits, poly.coeffs.as_slice())) + } else { + Either::Left((i, max_num_bits, poly)) + } + }); + drop(_guard); + drop(span); + let mut results = vec![Self::zero(); polys.len()]; + + // Handle CPU computations in parallel + let span = tracing::span!(tracing::Level::INFO, "batch_msm_cpu"); + let _guard = span.enter(); + let cpu_results: Vec<(usize, Self)> = cpu_batch + .into_par_iter() + .map(|(i, max_num_bits, poly)| { + ( + i, + Self::msm_field_elements( + &bases[..poly.borrow().coeffs.len()], + None, + &poly.borrow().coeffs, + Some(max_num_bits), + false, + ) + .unwrap(), + ) + }) + .collect(); + drop(_guard); + drop(span); + + // Store CPU results + for (i, result) in cpu_results { + results[i] = result; + } + + // Handle GPU computations if available + if !gpu_batch.is_empty() && use_icicle { + #[cfg(feature = "icicle")] + { + let span = tracing::span!(tracing::Level::INFO, "batch_msms_gpu"); + let _guard = span.enter(); + let mut backup = vec![]; + let gpu_bases = gpu_bases.unwrap_or_else(|| { + backup = Self::get_gpu_bases(bases); + &backup + }); + + let batched_results = icicle_variable_batch_msm(gpu_bases, &gpu_batch); + for (index, result) in batched_results { + results[index] = result; } } #[cfg(not(feature = "icicle"))] diff --git a/jolt-core/src/poly/commitment/binius.rs b/jolt-core/src/poly/commitment/binius.rs index 9efb5c694..c8009947c 100644 --- a/jolt-core/src/poly/commitment/binius.rs +++ b/jolt-core/src/poly/commitment/binius.rs @@ -7,6 +7,7 @@ use crate::poly::multilinear_polynomial::MultilinearPolynomial; use crate::utils::errors::ProofVerifyError; use crate::utils::transcript::{AppendToTranscript, Transcript}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::borrow::Borrow; use std::marker::PhantomData; #[derive(Clone)] @@ -50,11 +51,14 @@ impl CommitmentScheme ) -> Self::Commitment { todo!() } - fn batch_commit( - _polys: &[&MultilinearPolynomial], + fn batch_commit
<P>
( + _polys: &[P], _gens: &Self::Setup, _batch_type: BatchType, - ) -> Vec { + ) -> Vec + where + P: Borrow>, + { todo!() } fn prove( diff --git a/jolt-core/src/poly/commitment/commitment_scheme.rs b/jolt-core/src/poly/commitment/commitment_scheme.rs index 1e8fc0ec8..603b113ab 100644 --- a/jolt-core/src/poly/commitment/commitment_scheme.rs +++ b/jolt-core/src/poly/commitment/commitment_scheme.rs @@ -1,4 +1,5 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::borrow::Borrow; use std::fmt::Debug; use crate::utils::transcript::Transcript; @@ -48,11 +49,13 @@ pub trait CommitmentScheme: Clone + Sync + Send + ' fn setup(shapes: &[CommitShape]) -> Self::Setup; fn commit(poly: &MultilinearPolynomial, setup: &Self::Setup) -> Self::Commitment; - fn batch_commit( - polys: &[&MultilinearPolynomial], + fn batch_commit( + polys: &[U], gens: &Self::Setup, batch_type: BatchType, - ) -> Vec; + ) -> Vec + where + U: Borrow> + Sync; /// Homomorphically combines multiple commitments into a single commitment, computed as a /// linear combination with the given coefficients. diff --git a/jolt-core/src/poly/commitment/hyperkzg.rs b/jolt-core/src/poly/commitment/hyperkzg.rs index bb3b60017..5875106f5 100644 --- a/jolt-core/src/poly/commitment/hyperkzg.rs +++ b/jolt-core/src/poly/commitment/hyperkzg.rs @@ -15,7 +15,6 @@ use crate::field::JoltField; use crate::poly::commitment::commitment_scheme::CommitShape; use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}; use crate::utils::transcript::Transcript; -use crate::{field, into_optimal_iter}; use crate::{ msm::{Icicle, VariableBaseMSM}, poly::{commitment::kzg::SRS, dense_mlpoly::DensePolynomial, unipoly::UniPoly}, @@ -30,6 +29,7 @@ use rayon::iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; +use std::borrow::Borrow; use std::{marker::PhantomData, sync::Arc}; pub struct HyperKZGSRS(Arc>) @@ -101,18 +101,25 @@ pub struct HyperKZGProof { // the quotient of f(x)/(x-u) and (f(x) - f(v))/(x-u) is the // same. One advantage is that computing f(u) could be decoupled // from kzg_open, it could be done later or separate from computing W. -fn kzg_open_no_rem( +fn kzg_batch_open_no_rem( f: &MultilinearPolynomial, - u: P::ScalarField, + u: &[P::ScalarField], pk: &HyperKZGProverKey
<P>
, -) -> P::G1Affine +) -> Vec where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField,
<P as Pairing>
::G1: Icicle, { let f: &DensePolynomial = f.try_into().unwrap(); - let h = compute_witness_polynomial::
<P>
(&f.evals(), u); - UnivariateKZG::commit(&pk.kzg_pk, &UniPoly::from_coeff(h)).unwrap() + let h = u + .par_iter() + .map(|ui| { + let h = compute_witness_polynomial::
<P>
(&f.evals(), *ui); + MultilinearPolynomial::from(h) + }) + .collect::>(); + + UnivariateKZG::commit_batch(&pk.kzg_pk, &h).unwrap() } fn compute_witness_polynomial( @@ -120,7 +127,7 @@ fn compute_witness_polynomial( u: P::ScalarField, ) -> Vec where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField, { let d = f.len(); @@ -140,7 +147,7 @@ fn kzg_open_batch( transcript: &mut ProofTranscript, ) -> (Vec, Vec>) where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField,
<P as Pairing>
::G1: Icicle, { let k = f.len(); @@ -158,14 +165,13 @@ where }); // TODO(moodlezoup): Avoid cloned() - transcript.append_scalars(&v.iter().flatten().cloned().collect::>()); + let scalars = v.iter().flatten().collect::>(); + transcript.append_scalars::(&scalars); let q_powers: Vec = transcript.challenge_scalar_powers(f.len()); let B = MultilinearPolynomial::linear_combination(&f.iter().collect::>(), &q_powers); // Now open B at u0, ..., u_{t-1} - let w = into_optimal_iter!(u) - .map(|ui| kzg_open_no_rem(&B, *ui, pk)) - .collect::>(); + let w = kzg_batch_open_no_rem(&B, u, pk); // The prover computes the challenge to keep the transcript in the same // state as that of the verifier @@ -185,13 +191,14 @@ fn kzg_verify_batch( transcript: &mut ProofTranscript, ) -> bool where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField,
<P as Pairing>
::G1: Icicle, { let k = C.len(); let t = u.len(); - transcript.append_scalars(&v.iter().flatten().cloned().collect::>()); + let scalars = v.iter().flatten().collect::>(); + transcript.append_scalars::(&scalars); let q_powers: Vec = transcript.challenge_scalar_powers(k); transcript.append_points(&W.iter().map(|g| g.into_group()).collect::>()); @@ -268,7 +275,7 @@ pub struct HyperKZG { impl HyperKZG where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField,
<P as Pairing>
::G1: Icicle, { pub fn protocol_name() -> &'static [u8] { @@ -323,11 +330,7 @@ where assert_eq!(polys[ell - 1].len(), 2); // We do not need to commit to the first polynomial as it is already committed. - // Compute commitments in parallel - // TODO: This could be done by batch too if it gets progressively smaller. - let com: Vec = into_optimal_iter!(1..polys.len()) - .map(|i| UnivariateKZG::commit_as_univariate(&pk.kzg_pk, &polys[i]).unwrap()) - .collect(); + let com: Vec = UnivariateKZG::commit_variable_batch(&pk.kzg_pk, &polys[1..])?; // Phase 2 // We do not need to add x to the transcript, because in our context x was obtained from the transcript. @@ -407,7 +410,7 @@ where impl CommitmentScheme for HyperKZG where -
<P as Pairing>
::ScalarField: field::JoltField, +
<P as Pairing>
::ScalarField: JoltField,
<P as Pairing>
::G1: Icicle, { type Field = P::ScalarField; @@ -439,11 +442,14 @@ where } #[tracing::instrument(skip_all, name = "HyperKZG::batch_commit")] - fn batch_commit( - polys: &[&MultilinearPolynomial], + fn batch_commit( + polys: &[U], gens: &Self::Setup, _batch_type: BatchType, - ) -> Vec { + ) -> Vec + where + U: Borrow> + Sync, + { UnivariateKZG::commit_batch(&gens.0.kzg_pk, polys) .unwrap() .into_par_iter() diff --git a/jolt-core/src/poly/commitment/kzg.rs b/jolt-core/src/poly/commitment/kzg.rs index 72187894e..7015ec633 100644 --- a/jolt-core/src/poly/commitment/kzg.rs +++ b/jolt-core/src/poly/commitment/kzg.rs @@ -9,6 +9,7 @@ use ark_ff::PrimeField; use ark_std::{One, UniformRand, Zero}; use rand_core::{CryptoRng, RngCore}; use rayon::prelude::*; +use std::borrow::Borrow; use std::marker::PhantomData; use std::sync::Arc; @@ -194,34 +195,33 @@ where P::G1: Icicle, { #[tracing::instrument(skip_all, name = "KZG::commit_batch")] - pub fn commit_batch( + pub fn commit_batch( pk: &KZGProverKey
<P>
, - polys: &[&MultilinearPolynomial], - ) -> Result, ProofVerifyError> { - Self::commit_batch_with_mode(pk, polys, CommitMode::Default) - } - - #[tracing::instrument(skip_all, name = "KZG::commit_batch_with_mode")] - pub fn commit_batch_with_mode( - pk: &KZGProverKey
<P>
, - polys: &[&MultilinearPolynomial], - _mode: CommitMode, - ) -> Result, ProofVerifyError> { + polys: &[U], + ) -> Result, ProofVerifyError> + where + U: Borrow> + Sync, + { let g1_powers = &pk.g1_powers(); let gpu_g1 = pk.gpu_g1(); // batch commit requires all batches to have the same length - assert!(polys.par_iter().all(|s| s.len() == polys[0].len())); - assert!(polys[0].len() <= g1_powers.len()); - - if let Some(invalid) = polys.iter().find(|coeffs| coeffs.len() > g1_powers.len()) { + assert!(polys + .par_iter() + .all(|s| s.borrow().len() == polys[0].borrow().len())); + assert!(polys[0].borrow().len() <= g1_powers.len()); + + if let Some(invalid) = polys + .iter() + .find(|coeffs| (*coeffs).borrow().len() > g1_powers.len()) + { return Err(ProofVerifyError::KeyLengthError( g1_powers.len(), - invalid.len(), + invalid.borrow().len(), )); } - let msm_size = polys[0].len(); + let msm_size = polys[0].borrow().len(); let commitments = ::batch_msm( &g1_powers[..msm_size], gpu_g1.map(|g| &g[..msm_size]), @@ -230,13 +230,58 @@ where Ok(commitments.into_iter().map(|c| c.into_affine()).collect()) } + // This API will try to minimize copies to the GPU or just do the batches in parallel on the CPU + #[tracing::instrument(skip_all, name = "KZG::commit_variable_batch")] + pub fn commit_variable_batch( + pk: &KZGProverKey
<P>
, + polys: &[MultilinearPolynomial], + ) -> Result, ProofVerifyError> { + let g1_powers = &pk.g1_powers(); + let gpu_g1 = pk.gpu_g1(); + + // batch commit requires all batches be less than the bases in size + if let Some(invalid) = polys.iter().find(|poly| poly.len() > g1_powers.len()) { + return Err(ProofVerifyError::KeyLengthError( + g1_powers.len(), + invalid.len(), + )); + } + + let commitments = ::variable_batch_msm(g1_powers, gpu_g1, polys); + Ok(commitments.into_iter().map(|c| c.into_affine()).collect()) + } + + #[tracing::instrument(skip_all, name = "KZG::commit_variable_batch_univariate")] + pub fn commit_variable_batch_univariate( + pk: &KZGProverKey
<P>
, + polys: &[UniPoly], + ) -> Result, ProofVerifyError> { + let g1_powers = &pk.g1_powers(); + let gpu_g1 = pk.gpu_g1(); + + // batch commit requires all batches be less than the bases in size + if let Some(invalid) = polys + .iter() + .find(|poly| poly.coeffs.len() > g1_powers.len()) + { + return Err(ProofVerifyError::KeyLengthError( + g1_powers.len(), + invalid.coeffs.len(), + )); + } + + let commitments = + ::variable_batch_msm_univariate(g1_powers, gpu_g1, polys); + Ok(commitments.into_iter().map(|c| c.into_affine()).collect()) + } + #[tracing::instrument(skip_all, name = "KZG::commit_offset")] pub fn commit_offset( pk: &KZGProverKey
<P>
, poly: &UniPoly, offset: usize, ) -> Result { - Self::commit_inner(pk, &poly.coeffs, offset, CommitMode::Default) + Self::commit_inner(pk, &poly.coeffs, offset) } #[tracing::instrument(skip_all, name = "KZG::commit")] @@ -244,16 +289,7 @@ where pk: &KZGProverKey
<P>
, poly: &UniPoly, ) -> Result { - Self::commit_inner(pk, &poly.coeffs, 0, CommitMode::Default) - } - - #[tracing::instrument(skip_all, name = "KZG::commit_with_mode")] - pub fn commit_with_mode( - pk: &KZGProverKey
<P>
, - poly: &UniPoly, - mode: CommitMode, - ) -> Result { - Self::commit_inner(pk, &poly.coeffs, 0, mode) + Self::commit_inner(pk, &poly.coeffs, 0) } #[tracing::instrument(skip_all, name = "KZG::commit_as_univariate")] @@ -283,7 +319,6 @@ where pk: &KZGProverKey
<P>
, coeffs: &[P::ScalarField], offset: usize, - _mode: CommitMode, ) -> Result { if pk.g1_powers().len() < coeffs.len() { return Err(ProofVerifyError::KeyLengthError( @@ -351,7 +386,7 @@ mod test { use rand_chacha::ChaCha20Rng; use rand_core::SeedableRng; - fn run_kzg_test(degree_generator: F, commit_mode: CommitMode) -> Result<(), ProofVerifyError> + fn run_kzg_test(degree_generator: F) -> Result<(), ProofVerifyError> where F: Fn(&mut ChaCha20Rng) -> usize, { @@ -363,7 +398,7 @@ mod test { let pp = Arc::new(SRS::::setup(&mut rng, degree, 2)); let (ck, vk) = SRS::trim(pp, degree); let p = UniPoly::random::(degree, rng); - let comm = UnivariateKZG::::commit_with_mode(&ck, &p, commit_mode)?; + let comm = UnivariateKZG::::commit(&ck, &p)?; let point = Fr::rand(rng); let (proof, value) = UnivariateKZG::::open(&ck, &p, &point)?; assert!( @@ -378,12 +413,6 @@ mod test { #[test] fn kzg_commit_prove_verify() -> Result<(), ProofVerifyError> { - run_kzg_test(|rng| rng.gen_range(2..20), CommitMode::Default) - } - - #[test] - fn kzg_commit_prove_verify_mode() -> Result<(), ProofVerifyError> { - // This test uses the grand product optimization and ensures only powers of 2 are used for degree generation - run_kzg_test(|rng| 1 << rng.gen_range(1..8), CommitMode::GrandProduct) + run_kzg_test(|rng| rng.gen_range(2..20)) } } diff --git a/jolt-core/src/poly/commitment/mock.rs b/jolt-core/src/poly/commitment/mock.rs index fd30aa55a..3661c8c8d 100644 --- a/jolt-core/src/poly/commitment/mock.rs +++ b/jolt-core/src/poly/commitment/mock.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::marker::PhantomData; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; @@ -49,14 +50,17 @@ where fn commit(poly: &MultilinearPolynomial, _setup: &Self::Setup) -> Self::Commitment { MockCommitment { poly: poly.clone() } } - fn batch_commit( - polys: &[&MultilinearPolynomial], + fn batch_commit
<P>
( + polys: &[P], setup: &Self::Setup, _batch_type: BatchType, - ) -> Vec { + ) -> Vec + where + P: Borrow>, + { polys .into_iter() - .map(|poly| Self::commit(poly, setup)) + .map(|poly| Self::commit(poly.borrow(), setup)) .collect() } fn prove( diff --git a/jolt-core/src/poly/commitment/zeromorph.rs b/jolt-core/src/poly/commitment/zeromorph.rs index 8e4aec857..fc8293d2a 100644 --- a/jolt-core/src/poly/commitment/zeromorph.rs +++ b/jolt-core/src/poly/commitment/zeromorph.rs @@ -1,8 +1,6 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] -use std::{iter, marker::PhantomData}; - use crate::msm::{use_icicle, Icicle, VariableBaseMSM}; use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}; use crate::poly::{dense_mlpoly::DensePolynomial, unipoly::UniPoly}; @@ -17,14 +15,15 @@ use ark_std::{One, Zero}; use itertools::izip; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use rand_core::{CryptoRng, RngCore}; +use std::borrow::Borrow; use std::sync::Arc; +use std::{iter, marker::PhantomData}; use super::{ commitment_scheme::{BatchType, CommitShape, CommitmentScheme}, kzg::{KZGProverKey, KZGVerifierKey, UnivariateKZG, SRS}, }; use crate::field::JoltField; -use crate::optimal_iter; use rayon::prelude::*; pub struct ZeromorphSRS(Arc>) @@ -290,10 +289,7 @@ where assert_eq!(quotients.len(), poly.get_num_vars()); assert_eq!(remainder, *eval); - // TODO(sagar): support variable_batch msms - or decide not to support them altogether - let q_k_com: Vec = optimal_iter!(quotients) - .map(|q| UnivariateKZG::commit(&pp.commit_pp, q).unwrap()) - .collect(); + let q_k_com = UnivariateKZG::commit_variable_batch_univariate(&pp.commit_pp, "ients)?; let q_comms: Vec = q_k_com.par_iter().map(|c| c.into_group()).collect(); // Compute the multilinear quotients q_k = q_k(X_0, ..., X_{k-1}) // let quotient_slices: Vec<&[P::ScalarField]> = @@ -456,11 +452,14 @@ where ZeromorphCommitment(UnivariateKZG::commit_as_univariate(&setup.0.commit_pp, poly).unwrap()) } - fn batch_commit( - polys: &[&MultilinearPolynomial], + fn batch_commit( + polys: &[U], gens: &Self::Setup, _batch_type: BatchType, - ) -> Vec { + ) -> Vec + where + U: Borrow> + Sync, + { UnivariateKZG::commit_batch(&gens.0.commit_pp, polys) .unwrap() .into_iter() diff --git a/jolt-core/src/poly/compact_polynomial.rs b/jolt-core/src/poly/compact_polynomial.rs index ac2c00565..f078942cb 100644 --- a/jolt-core/src/poly/compact_polynomial.rs +++ b/jolt-core/src/poly/compact_polynomial.rs @@ -1,5 +1,6 @@ use std::ops::Index; +use super::multilinear_polynomial::{BindingOrder, PolynomialBinding}; use crate::utils::math::Math; use crate::utils::thread::unsafe_allocate_zero_vec; use crate::{field::JoltField, utils}; @@ -9,8 +10,6 @@ use ark_serialize::{ use num_integer::Integer; use rayon::prelude::*; -use super::multilinear_polynomial::{BindingOrder, PolynomialBinding}; - pub trait SmallScalar: Copy + Integer + Sync { /// Performs a field multiplication. Uses `JoltField::mul_u64_unchecked` under the hood. 
/// WARNING: Does not convert the small scalar into Montgomery form before performing @@ -145,6 +144,10 @@ impl CompactPolynomial { pub fn iter(&self) -> impl Iterator { self.coeffs.iter() } + + pub fn coeffs_as_field_elements(&self) -> Vec { + self.coeffs.par_iter().map(|x| x.to_field()).collect() + } } impl PolynomialBinding for CompactPolynomial { diff --git a/jolt-core/src/poly/multilinear_polynomial.rs b/jolt-core/src/poly/multilinear_polynomial.rs index da08f68dd..3f3639745 100644 --- a/jolt-core/src/poly/multilinear_polynomial.rs +++ b/jolt-core/src/poly/multilinear_polynomial.rs @@ -78,30 +78,30 @@ impl MultilinearPolynomial { /// The maximum number of bits occupied by one of the polynomial's coefficients. #[tracing::instrument(skip_all)] - pub fn max_num_bits(&self) -> u32 { + pub fn max_num_bits(&self) -> usize { match self { MultilinearPolynomial::LargeScalars(poly) => poly .evals_ref() .par_iter() .map(|s| s.num_bits()) .max() - .unwrap(), + .unwrap() as usize, MultilinearPolynomial::U8Scalars(poly) => { - (*poly.coeffs.iter().max().unwrap() as usize).num_bits() as u32 + (*poly.coeffs.iter().max().unwrap() as usize).num_bits() } MultilinearPolynomial::U16Scalars(poly) => { - (*poly.coeffs.iter().max().unwrap() as usize).num_bits() as u32 + (*poly.coeffs.iter().max().unwrap() as usize).num_bits() } MultilinearPolynomial::U32Scalars(poly) => { - (*poly.coeffs.iter().max().unwrap() as usize).num_bits() as u32 + (*poly.coeffs.iter().max().unwrap() as usize).num_bits() } MultilinearPolynomial::U64Scalars(poly) => { - (*poly.coeffs.iter().max().unwrap() as usize).num_bits() as u32 + (*poly.coeffs.iter().max().unwrap() as usize).num_bits() } MultilinearPolynomial::I64Scalars(_) => { // HACK(moodlezoup): i64 coefficients are converted into full-width field // elements before computing the MSM - F::NUM_BYTES as u32 * 8 + F::NUM_BYTES * 8 } } } @@ -632,3 +632,100 @@ impl PolynomialEvaluation for MultilinearPolynomial { } } } + +#[cfg(test)] +mod tests { + use super::*; + use ark_bn254::Fr; + use rand_chacha::ChaCha20Rng; + use rand_core::{RngCore, SeedableRng}; + + fn random_poly(max_num_bits: usize, len: usize) -> MultilinearPolynomial { + let mut rng = ChaCha20Rng::seed_from_u64(len as u64); + match max_num_bits { + 0 => MultilinearPolynomial::from(vec![0u8; len]), + 1..=8 => MultilinearPolynomial::from( + (0..len) + .map(|_| { + let mask = if max_num_bits == 8 { + u8::MAX + } else { + (1u8 << max_num_bits) - 1 + }; + (rng.next_u32() & (mask as u32)) as u8 + }) + .collect::>(), + ), + 9..=16 => MultilinearPolynomial::from( + (0..len) + .map(|_| { + let mask = if max_num_bits == 16 { + u16::MAX + } else { + (1u16 << max_num_bits) - 1 + }; + (rng.next_u32() & (mask as u32)) as u16 + }) + .collect::>(), + ), + 17..=32 => MultilinearPolynomial::from( + (0..len) + .map(|_| { + let mask = if max_num_bits == 32 { + u32::MAX + } else { + (1u32 << max_num_bits) - 1 + }; + (rng.next_u64() & (mask as u64)) as u32 + }) + .collect::>(), + ), + 33..=64 => MultilinearPolynomial::from( + (0..len) + .map(|_| { + let mask = if max_num_bits == 64 { + u64::MAX + } else { + (1u64 << max_num_bits) - 1 + }; + rng.next_u64() & mask + }) + .collect::>(), + ), + _ => MultilinearPolynomial::from( + (0..len).map(|_| Fr::random(&mut rng)).collect::>(), + ), + } + } + + #[test] + fn test_poly_to_field_elements() { + let small_value_lookup_tables = ::compute_lookup_tables(); + ::initialize_lookup_tables(small_value_lookup_tables); + + let max_num_bits = [ + vec![8; 100], + vec![16; 100], + vec![32; 100], + vec![64; 
100], + vec![256; 300], + ] + .concat(); + + for &max_num_bits in max_num_bits.iter() { + let len = 1 << 2; + let poly = random_poly(max_num_bits, len); + let field_elements: Vec = match poly { + MultilinearPolynomial::U8Scalars(poly) => poly.coeffs_as_field_elements(), + MultilinearPolynomial::U16Scalars(poly) => poly.coeffs_as_field_elements(), + MultilinearPolynomial::U32Scalars(poly) => poly.coeffs_as_field_elements(), + MultilinearPolynomial::U64Scalars(poly) => poly.coeffs_as_field_elements(), + MultilinearPolynomial::LargeScalars(poly) => poly.evals(), + _ => { + panic!("Unexpected MultilinearPolynomial variant"); + } + }; + assert_eq!(field_elements.len(), len); + } + } +} diff --git a/jolt-core/src/utils/transcript.rs b/jolt-core/src/utils/transcript.rs index 7b2349bc3..a2067e5f4 100644 --- a/jolt-core/src/utils/transcript.rs +++ b/jolt-core/src/utils/transcript.rs @@ -2,6 +2,7 @@ use crate::field::JoltField; use ark_ec::{AffineRepr, CurveGroup}; use ark_serialize::CanonicalSerialize; use sha3::{Digest, Keccak256}; +use std::borrow::Borrow; /// Represents the current state of the protocol's Fiat-Shamir transcript. #[derive(Clone)] @@ -142,10 +143,10 @@ impl Transcript for KeccakTranscript { self.append_bytes(&buf); } - fn append_scalars(&mut self, scalars: &[F]) { + fn append_scalars(&mut self, scalars: &[impl Borrow]) { self.append_message(b"begin_append_vector"); for item in scalars.iter() { - self.append_scalar(item); + self.append_scalar(item.borrow()); } self.append_message(b"end_append_vector"); } @@ -215,7 +216,7 @@ pub trait Transcript: Clone + Sync + Send + 'static { fn append_bytes(&mut self, bytes: &[u8]); fn append_u64(&mut self, x: u64); fn append_scalar(&mut self, scalar: &F); - fn append_scalars(&mut self, scalars: &[F]); + fn append_scalars(&mut self, scalars: &[impl Borrow]); fn append_point(&mut self, point: &G); fn append_points(&mut self, points: &[G]); fn challenge_scalar(&mut self) -> F;
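Note on the API change above: with `batch_commit` now generic over `U: Borrow<MultilinearPolynomial<Self::Field>>`, call sites can pass a slice of owned polynomials directly instead of first collecting a `Vec<&MultilinearPolynomial<F>>` (see the `jolt/vm/mod.rs` and `lasso/surge.rs` hunks). A minimal call-site sketch, assuming the import paths below and that `CommitmentScheme` is parameterized by a `ProofTranscript: Transcript` as elsewhere in the crate; the helper name `commit_owned_batch` is illustrative, not part of this diff:

use jolt_core::field::JoltField;
use jolt_core::poly::commitment::commitment_scheme::{BatchType, CommitmentScheme};
use jolt_core::poly::multilinear_polynomial::MultilinearPolynomial;
use jolt_core::utils::transcript::Transcript;

// Commit to a batch of owned polynomials. No intermediate Vec<&MultilinearPolynomial<F>> is built:
// &[MultilinearPolynomial<F>] already satisfies `U: Borrow<MultilinearPolynomial<F>> + Sync`.
fn commit_owned_batch<F, PCS, ProofTranscript>(
    polys: &[MultilinearPolynomial<F>],
    setup: &PCS::Setup,
) -> Vec<PCS::Commitment>
where
    F: JoltField,
    ProofTranscript: Transcript,
    PCS: CommitmentScheme<ProofTranscript, Field = F>,
{
    // Before this change: PCS::batch_commit(&polys.iter().collect::<Vec<_>>(), setup, BatchType::Big)
    PCS::batch_commit(polys, setup, BatchType::Big)
}

The same `Borrow` bound is what lets `UnivariateKZG::commit_batch` accept either `&[MultilinearPolynomial<F>]` or `&[&MultilinearPolynomial<F>]` from the `HyperKZG` and `Zeromorph` callers above.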