Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: optimize opening proof gpu usage #569

Merged
merged 17 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 13 additions & 6 deletions jolt-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,23 @@ name = "jolt_core"
path = "src/lib.rs"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
memory-stats = "1.0.0"
sys-info = "0.9.1"
tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] }

[target.'cfg(all(not(target_arch = "wasm32"), target_os = "macos"))'.dependencies]
icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true }
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true }
icicle-bn254 = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true }

[target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "macos")))'.dependencies]
icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "94fe8ca", optional = true }
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", features = [
"cuda_backend",
], rev = "ed93e21", optional = true }
icicle-core = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", rev = "ed93e21", optional = true }
], rev = "94fe8ca", optional = true }
icicle-bn254 = { git = "https://github.com/ingonyama-zk/icicle-jolt.git", features = [
"cuda_backend",
], rev = "ed93e21", optional = true }
memory-stats = "1.0.0"
sys-info = "0.9.1"
tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] }
], rev = "94fe8ca", optional = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
6 changes: 1 addition & 5 deletions jolt-core/benches/commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ fn benchmark_commit<PCS, F, ProofTranscript>(
&format!("{} Commit(mode:{:?}): {}% Ones", name, mode, threshold),
|b| {
b.iter(|| {
PCS::batch_commit(
&leaves.iter().collect::<Vec<_>>(),
&setup,
batch_type.clone(),
);
PCS::batch_commit(&leaves, &setup, batch_type.clone());
});
},
);
Expand Down
3 changes: 3 additions & 0 deletions jolt-core/benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ where
}

fn main() {
let small_value_lookup_tables = <Fr as JoltField>::compute_lookup_tables();
<Fr as JoltField>::initialize_lookup_tables(small_value_lookup_tables);

let mut criterion = Criterion::default()
.configure_from_args()
.sample_size(20)
Expand Down
52 changes: 38 additions & 14 deletions jolt-core/benches/msm_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,55 @@ fn random_poly(max_num_bits: usize, len: usize) -> MultilinearPolynomial<Fr> {
0 => MultilinearPolynomial::from(vec![0u8; len]),
1..=8 => MultilinearPolynomial::from(
(0..len)
.into_iter()
.map(|_| (rng.next_u32() & ((1 << max_num_bits) - 1)) as u8)
.map(|_| {
let mask = if max_num_bits == 8 {
u8::MAX
} else {
(1u8 << max_num_bits) - 1
};
(rng.next_u32() & (mask as u32)) as u8
})
.collect::<Vec<_>>(),
),
9..=16 => MultilinearPolynomial::from(
(0..len)
.into_iter()
.map(|_| (rng.next_u32() & ((1 << max_num_bits) - 1)) as u16)
.map(|_| {
let mask = if max_num_bits == 16 {
u16::MAX
} else {
(1u16 << max_num_bits) - 1
};
(rng.next_u32() & (mask as u32)) as u16
})
.collect::<Vec<_>>(),
),
17..=32 => MultilinearPolynomial::from(
(0..len)
.into_iter()
.map(|_| (rng.next_u64() & ((1 << max_num_bits) - 1)) as u32)
.map(|_| {
let mask = if max_num_bits == 32 {
u32::MAX
} else {
(1u32 << max_num_bits) - 1
};
(rng.next_u64() & (mask as u64)) as u32
})
.collect::<Vec<_>>(),
),
33..=64 => MultilinearPolynomial::from(
(0..len)
.into_iter()
.map(|_| rng.next_u64() & ((1 << max_num_bits) - 1))
.collect::<Vec<_>>(),
),
_ => MultilinearPolynomial::from(
(0..len)
.into_iter()
.map(|_| Fr::random(&mut rng))
.map(|_| {
let mask = if max_num_bits == 64 {
u64::MAX
} else {
(1u64 << max_num_bits) - 1
};
rng.next_u64() & mask
})
.collect::<Vec<_>>(),
),
_ => {
MultilinearPolynomial::from((0..len).map(|_| Fr::random(&mut rng)).collect::<Vec<_>>())
}
}
}

Expand Down Expand Up @@ -120,6 +141,9 @@ fn benchmark_msm_batch<PCS, F, ProofTranscript>(
}

fn main() {
let small_value_lookup_tables = <Fr as JoltField>::compute_lookup_tables();
<Fr as JoltField>::initialize_lookup_tables(small_value_lookup_tables);

let mut criterion = Criterion::default()
.configure_from_args()
.sample_size(10)
Expand Down
6 changes: 1 addition & 5 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,7 @@ impl<F: JoltField> JoltPolynomials<F> {
|| PCS::commit(&self.read_write_memory.t_final, &preprocessing.generators)
);
commitments.instruction_lookups.final_cts = PCS::batch_commit(
&self
.instruction_lookups
.final_cts
.iter()
.collect::<Vec<_>>(),
&self.instruction_lookups.final_cts,
&preprocessing.generators,
BatchType::Big,
);
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/lasso/surge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ where
.zip(trace_comitments.into_iter())
.for_each(|(dest, src)| *dest = src);
commitments.final_cts = PCS::batch_commit(
&polynomials.final_cts.iter().collect::<Vec<_>>(),
&polynomials.final_cts,
generators,
BatchType::SurgeInitFinal,
);
Expand Down
91 changes: 91 additions & 0 deletions jolt-core/src/msm/icicle/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,97 @@ where
V::to_ark_projective(&msm_host_result[0])
}

// batch_info is a tuple of (batch_id, bit_size, scalars)
pub type BatchInfo<'a, V: VariableBaseMSM> = (usize, usize, &'a [V::ScalarField]);

#[tracing::instrument(skip_all, name = "icicle_variable_batch_msm")]
pub fn icicle_variable_batch_msm<V>(
bases: &[GpuBaseType<V>],
batch_info: &[BatchInfo<V>],
) -> Vec<(usize, V)>
where
V: VariableBaseMSM,
V::ScalarField: JoltField,
{
let mut stream = IcicleStream::create().unwrap();

let mut bases_slice =
DeviceVec::<GpuBaseType<V>>::device_malloc_async(bases.len(), &stream).unwrap();
let span = tracing::span!(tracing::Level::INFO, "copy_bases_to_gpu");
let _guard = span.enter();
bases_slice
.copy_from_host_async(HostSlice::from_slice(bases), &stream)
.unwrap();
drop(_guard);
drop(span);

let num_batches = batch_info.len();
let mut msm_result =
DeviceVec::<Projective<V::C>>::device_malloc_async(num_batches, &stream).unwrap();
let mut msm_host_results = vec![Projective::<V::C>::zero(); num_batches];

for (index, (_batch_id, bit_size, scalars)) in batch_info.iter().enumerate() {
let span = tracing::span!(tracing::Level::INFO, "convert_scalars");
let _guard = span.enter();

let mut scalars_slice =
DeviceVec::<<<V as Icicle>::C as Curve>::ScalarField>::device_malloc_async(
scalars.len(),
&stream,
)
.unwrap();
let scalars_mont = unsafe {
&*(&scalars[..] as *const _ as *const [<<V as Icicle>::C as Curve>::ScalarField])
};
drop(_guard);
drop(span);

let span = tracing::span!(tracing::Level::INFO, "copy_scalars_to_gpu");
let _guard = span.enter();
scalars_slice
.copy_from_host_async(HostSlice::from_slice(scalars_mont), &stream)
.unwrap();
drop(_guard);
drop(span);

let mut cfg = MSMConfig::default();
cfg.stream_handle = IcicleStreamHandle::from(&stream);
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
cfg.bitsize = *bit_size as i32;

let span = tracing::span!(tracing::Level::INFO, "gpu_msm");
let _guard = span.enter();

msm(
&scalars_slice,
&bases_slice[..scalars.len()],
&cfg,
&mut msm_result[index..index + 1],
)
.unwrap();

drop(_guard);
drop(span);
}

let span = tracing::span!(tracing::Level::INFO, "copy_msm_result");
let _guard = span.enter();
msm_result
.copy_to_host(HostSlice::from_mut_slice(&mut msm_host_results))
.unwrap();
drop(_guard);
drop(span);

stream.synchronize().unwrap();
stream.destroy().unwrap();
batch_info
.par_iter()
.zip(msm_host_results)
.map(|((batch_id, _, _), result)| (*batch_id, V::to_ark_projective(&result)))
.collect()
}

/// Batch process msms - assumes batches are equal in size
/// Variable Batch sizes is not currently supported by icicle
#[tracing::instrument(skip_all)]
Expand Down
Loading