diff --git a/Cargo.toml b/Cargo.toml index a8372f7..2af426f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ ark-ff = { version = "0.4.0" } ark-serialize = { version = "0.4.2" } ark-std = { version = "0.4.0" } rayon = { version = "1.5" } -blitzar-sys = { version = "1.78.0" } +blitzar-sys = { version = "1.81.0" } curve25519-dalek = { version = "4", features = ["serde"] } merlin = "2" serde = { version = "1", features = ["serde_derive"] } @@ -33,6 +33,7 @@ criterion = { version = "0.3", features = ["html_reports"] } curve25519-dalek = { version = "4", features = ["rand_core"] } rand = "0.8" rand_core = "0.6" +tempfile = "3.13.0" [[bench]] harness = false diff --git a/src/compute/fixed_msm.rs b/src/compute/fixed_msm.rs index cc2ce80..c5049cd 100644 --- a/src/compute/fixed_msm.rs +++ b/src/compute/fixed_msm.rs @@ -2,7 +2,7 @@ use super::backend::init_backend; use crate::compute::{curve::SwCurveConfig, CurveId, ElementP2}; use ark_ec::short_weierstrass::Affine; use rayon::prelude::*; -use std::marker::PhantomData; +use std::{ffi::CString, marker::PhantomData}; fn count_scalars_per_output(scalars_len: usize, output_bit_table: &[u32]) -> u32 { let bit_sum: usize = output_bit_table.iter().map(|s| *s as usize).sum(); @@ -46,6 +46,34 @@ impl MsmHandle { } } + /// New handle from a serialized file. + /// + /// Note: any MSMs computed with the handle must have length less than or equal + /// to the number of generators used to create the handle. + pub fn new_from_file(filename: &str) -> Self { + init_backend(); + let filename = CString::new(filename).expect("filename cannot have null bytes"); + unsafe { + let handle = + blitzar_sys::sxt_multiexp_handle_new_from_file(T::CURVE_ID, filename.as_ptr()); + Self { + handle, + phantom: PhantomData, + } + } + } + + /// Serialize the handle to a file. + /// + /// This function can be used together with new_from_file to reduce + /// the cost of creating a handle. + pub fn write(&self, filename: &str) { + let filename = CString::new(filename).expect("filename cannot have null bytes"); + unsafe { + blitzar_sys::sxt_multiexp_handle_write_to_file(self.handle, filename.as_ptr()); + } + } + /// Compute an MSM using pre-specified generators. /// /// Suppose g_1, ..., g_n are pre-specified generators and diff --git a/src/compute/fixed_msm_tests.rs b/src/compute/fixed_msm_tests.rs index 17d4778..9ebe175 100644 --- a/src/compute/fixed_msm_tests.rs +++ b/src/compute/fixed_msm_tests.rs @@ -4,6 +4,7 @@ use ark_bls12_381::G1Affine; use ark_std::UniformRand; use curve25519_dalek::ristretto::RistrettoPoint; use rand_core::OsRng; +use tempfile::TempDir; #[test] fn we_can_compute_msms_using_a_single_generator() { @@ -48,6 +49,33 @@ fn we_can_compute_msms_using_multiple_generator() { assert_eq!(res[0], generators[0] + generators[1] + generators[1]); } +#[test] +fn we_can_serialize_a_handle_to_a_file() { + let mut rng = OsRng; + + let mut res = vec![RistrettoPoint::default(); 1]; + + // randomly obtain the generator points + let generators: Vec = + (0..2).map(|_| RistrettoPoint::random(&mut rng)).collect(); + + // create handle + let handle = MsmHandle::new(&generators); + + // write the handle to a file + let tmp_dir = TempDir::new().unwrap(); + let filename = tmp_dir.path().join("t").to_str().unwrap().to_string(); + handle.write(&filename); + + // read the handle back from file + let handle = MsmHandle::::new_from_file(&filename); + + // we can compute a multiexponentiation + let scalars: Vec = vec![1, 2]; + handle.msm(&mut res, 1, &scalars); + assert_eq!(res[0], generators[0] + generators[1] + generators[1]); +} + #[test] fn we_can_compute_msms_using_multiple_outputs() { let mut rng = OsRng;