From 1041c7cecbbbc00bd6fcacade9d816bd55b4b951 Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 15 Jan 2025 12:50:15 +0100 Subject: [PATCH 1/4] refactor(prf): adapt prf to new mpz vm Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com> --- crates/components/hmac-sha256/Cargo.toml | 12 +- crates/components/hmac-sha256/benches/prf.rs | 311 ++++++------ crates/components/hmac-sha256/src/error.rs | 26 +- crates/components/hmac-sha256/src/lib.rs | 287 +++++------ crates/components/hmac-sha256/src/prf.rs | 492 ++++++------------- 5 files changed, 482 insertions(+), 646 deletions(-) diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index b9cb86bbe..d19ce0495 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -19,21 +19,23 @@ mock = [] [dependencies] tlsn-hmac-sha256-circuits = { workspace = true } -mpz-garble = { workspace = true } +mpz-vm-core = { workspace = true } mpz-circuits = { workspace = true } mpz-common = { workspace = true } -async-trait = { workspace = true } derive_builder = { workspace = true } -futures = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } +futures = { workspace = true } [dev-dependencies] -criterion = { workspace = true, features = ["async_tokio"] } -mpz-common = { workspace = true, features = ["test-utils"] } mpz-ot = { workspace = true, features = ["ideal"] } +mpz-garble = { workspace = true } +mpz-common = { workspace = true, features = ["test-utils"] } + +criterion = { workspace = true, features = ["async_tokio"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } +rand = { workspace = true } [[bench]] name = "prf" diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 1a31aba39..6b826c4bf 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -1,9 +1,14 @@ use criterion::{criterion_group, criterion_main, Criterion}; use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role}; -use mpz_common::executor::test_mt_executor; -use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory}; -use mpz_ot::ideal::ot::ideal_ot; +use mpz_common::executor::{mt::MTConfig, test_mt_executor}; +use mpz_garble::protocol::semihonest::{Evaluator, Generator}; +use mpz_ot::ideal::cot::ideal_cot; +use mpz_vm_core::{ + memory::{binary::U8, correlated::Delta, Array}, + prelude::*, +}; +use rand::{rngs::StdRng, SeedableRng}; #[allow(clippy::unit_arg)] fn criterion_benchmark(c: &mut Criterion) { @@ -11,178 +16,194 @@ fn criterion_benchmark(c: &mut Criterion) { group.sample_size(10); let rt = tokio::runtime::Runtime::new().unwrap(); - group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess)); + //group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess)); group.bench_function("prf", |b| b.to_async(&rt).iter(prf)); } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); -async fn preprocess() { - let (mut leader_exec, mut follower_exec) = test_mt_executor(128); +// async fn preprocess() { +// let (mut leader_exec, mut follower_exec) = test_mt_executor(128); + +// let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); +// let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); +// let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); +// let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); + +// let leader_thread_0 = DEAPThread::new( +// DEAPRole::Leader, +// [0u8; 32], +// leader_exec.new_thread().await.unwrap(), +// leader_ot_send_0, +// leader_ot_recv_0, +// ); +// let leader_thread_1 = leader_thread_0 +// .new_thread( +// leader_exec.new_thread().await.unwrap(), +// leader_ot_send_1, +// leader_ot_recv_1, +// ) +// .unwrap(); + +// let follower_thread_0 = DEAPThread::new( +// DEAPRole::Follower, +// [0u8; 32], +// follower_exec.new_thread().await.unwrap(), +// follower_ot_send_0, +// follower_ot_recv_0, +// ); +// let follower_thread_1 = follower_thread_0 +// .new_thread( +// follower_exec.new_thread().await.unwrap(), +// follower_ot_send_1, +// follower_ot_recv_1, +// ) +// .unwrap(); + +// let leader_pms = leader_thread_0.new_public_input::<[u8; +// 32]>("pms").unwrap(); let follower_pms = follower_thread_0 +// .new_public_input::<[u8; 32]>("pms") +// .unwrap(); + +// let mut leader = MpcPrf::new( +// PrfConfig::builder().role(Role::Leader).build().unwrap(), +// leader_thread_0, +// leader_thread_1, +// ); +// let mut follower = MpcPrf::new( +// PrfConfig::builder().role(Role::Follower).build().unwrap(), +// follower_thread_0, +// follower_thread_1, +// ); + +// futures::join!( +// async { +// leader.setup(leader_pms).await.unwrap(); +// leader.set_client_random(Some([0u8; 32])).await.unwrap(); +// leader.preprocess().await.unwrap(); +// }, +// async { +// follower.setup(follower_pms).await.unwrap(); +// follower.set_client_random(None).await.unwrap(); +// follower.preprocess().await.unwrap(); +// } +// ); +// } - let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); - let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); - let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); - let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); +async fn prf() { + let mut rng = StdRng::seed_from_u64(0); - let leader_thread_0 = DEAPThread::new( - DEAPRole::Leader, - [0u8; 32], - leader_exec.new_thread().await.unwrap(), - leader_ot_send_0, - leader_ot_recv_0, - ); - let leader_thread_1 = leader_thread_0 - .new_thread( - leader_exec.new_thread().await.unwrap(), - leader_ot_send_1, - leader_ot_recv_1, - ) - .unwrap(); + let pms = [42u8; 32]; + let client_random = [69u8; 32]; + let server_random: [u8; 32] = [96u8; 32]; - let follower_thread_0 = DEAPThread::new( - DEAPRole::Follower, - [0u8; 32], - follower_exec.new_thread().await.unwrap(), - follower_ot_send_0, - follower_ot_recv_0, - ); - let follower_thread_1 = follower_thread_0 - .new_thread( - follower_exec.new_thread().await.unwrap(), - follower_ot_send_1, - follower_ot_recv_1, - ) + let (mut leader_exec, mut follower_exec) = test_mt_executor(128, MTConfig::default()); + let mut leader_ctx = leader_exec.new_thread().await.unwrap(); + let mut follower_ctx = follower_exec.new_thread().await.unwrap(); + + let delta = Delta::random(&mut rng); + let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); + + let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); + let mut follower_vm = Evaluator::new(ot_recv); + + let leader_pms: Array = leader_vm.alloc().unwrap(); + leader_vm.mark_public(leader_pms).unwrap(); + leader_vm.assign(leader_pms, pms).unwrap(); + leader_vm.commit(leader_pms).unwrap(); + + let follower_pms: Array = follower_vm.alloc().unwrap(); + follower_vm.mark_public(follower_pms).unwrap(); + follower_vm.assign(follower_pms, pms).unwrap(); + follower_vm.commit(follower_pms).unwrap(); + + let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); + let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); + + let leader_output = leader.setup(&mut leader_vm, leader_pms).unwrap(); + let follower_output = follower.setup(&mut follower_vm, follower_pms).unwrap(); + + leader + .set_client_random(&mut leader_vm, Some(client_random)) .unwrap(); + follower.set_client_random(&mut follower_vm, None).unwrap(); - let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); - let follower_pms = follower_thread_0 - .new_public_input::<[u8; 32]>("pms") + leader + .set_server_random(&mut leader_vm, server_random) + .unwrap(); + follower + .set_server_random(&mut follower_vm, server_random) .unwrap(); - let mut leader = MpcPrf::new( - PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_thread_0, - leader_thread_1, - ); - let mut follower = MpcPrf::new( - PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_thread_0, - follower_thread_1, - ); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm + .decode(leader_output.keys.client_write_key) + .unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm + .decode(leader_output.keys.server_write_key) + .unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm.decode(leader_output.keys.client_iv).unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm.decode(leader_output.keys.server_iv).unwrap(); + + #[allow(clippy::let_underscore_future)] + let _ = follower_vm + .decode(follower_output.keys.client_write_key) + .unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = follower_vm + .decode(follower_output.keys.server_write_key) + .unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap(); futures::join!( async { - leader.setup(leader_pms).await.unwrap(); - leader.set_client_random(Some([0u8; 32])).await.unwrap(); - leader.preprocess().await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); + leader_vm.execute(&mut leader_ctx).await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); }, async { - follower.setup(follower_pms).await.unwrap(); - follower.set_client_random(None).await.unwrap(); - follower.preprocess().await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); + follower_vm.execute(&mut follower_ctx).await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); } ); -} -async fn prf() { - let (mut leader_exec, mut follower_exec) = test_mt_executor(128); - - let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); - let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); - let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); - let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); - - let leader_thread_0 = DEAPThread::new( - DEAPRole::Leader, - [0u8; 32], - leader_exec.new_thread().await.unwrap(), - leader_ot_send_0, - leader_ot_recv_0, - ); - let leader_thread_1 = leader_thread_0 - .new_thread( - leader_exec.new_thread().await.unwrap(), - leader_ot_send_1, - leader_ot_recv_1, - ) - .unwrap(); + let cf_hs_hash = [1u8; 32]; + let sf_hs_hash = [2u8; 32]; - let follower_thread_0 = DEAPThread::new( - DEAPRole::Follower, - [0u8; 32], - follower_exec.new_thread().await.unwrap(), - follower_ot_send_0, - follower_ot_recv_0, - ); - let follower_thread_1 = follower_thread_0 - .new_thread( - follower_exec.new_thread().await.unwrap(), - follower_ot_send_1, - follower_ot_recv_1, - ) - .unwrap(); + leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); + leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); - let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); - let follower_pms = follower_thread_0 - .new_public_input::<[u8; 32]>("pms") - .unwrap(); + follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); + follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); - let mut leader = MpcPrf::new( - PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_thread_0, - leader_thread_1, - ); - let mut follower = MpcPrf::new( - PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_thread_0, - follower_thread_1, - ); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); - let pms = [42u8; 32]; - let client_random = [0u8; 32]; - let server_random = [1u8; 32]; - let cf_hs_hash = [2u8; 32]; - let sf_hs_hash = [3u8; 32]; + #[allow(clippy::let_underscore_future)] + let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); + #[allow(clippy::let_underscore_future)] + let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); futures::join!( async { - leader.setup(leader_pms.clone()).await.unwrap(); - leader.set_client_random(Some(client_random)).await.unwrap(); - leader.preprocess().await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); + leader_vm.execute(&mut leader_ctx).await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); }, async { - follower.setup(follower_pms.clone()).await.unwrap(); - follower.set_client_random(None).await.unwrap(); - follower.preprocess().await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); + follower_vm.execute(&mut follower_ctx).await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); } ); - - leader.thread_mut().assign(&leader_pms, pms).unwrap(); - follower.thread_mut().assign(&follower_pms, pms).unwrap(); - - let (_leader_keys, _follower_keys) = futures::try_join!( - leader.compute_session_keys(server_random), - follower.compute_session_keys(server_random) - ) - .unwrap(); - - let _ = futures::try_join!( - leader.compute_client_finished_vd(cf_hs_hash), - follower.compute_client_finished_vd(cf_hs_hash) - ) - .unwrap(); - - let _ = futures::try_join!( - leader.compute_server_finished_vd(sf_hs_hash), - follower.compute_server_finished_vd(sf_hs_hash) - ) - .unwrap(); - - futures::try_join!( - leader.thread_mut().finalize(), - follower.thread_mut().finalize() - ) - .unwrap(); } diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs index ec8163825..d22f75494 100644 --- a/crates/components/hmac-sha256/src/error.rs +++ b/crates/components/hmac-sha256/src/error.rs @@ -33,6 +33,10 @@ impl PrfError { source: Some(msg.into().into()), } } + + pub(crate) fn vm>>(err: E) -> Self { + Self::new(ErrorKind::Vm, err) + } } #[derive(Debug)] @@ -58,26 +62,8 @@ impl fmt::Display for PrfError { } } -impl From for PrfError { - fn from(error: mpz_garble::MemoryError) -> Self { - Self::new(ErrorKind::Vm, error) - } -} - -impl From for PrfError { - fn from(error: mpz_garble::LoadError) -> Self { - Self::new(ErrorKind::Vm, error) - } -} - -impl From for PrfError { - fn from(error: mpz_garble::ExecutionError) -> Self { - Self::new(ErrorKind::Vm, error) - } -} - -impl From for PrfError { - fn from(error: mpz_garble::DecodeError) -> Self { +impl From for PrfError { + fn from(error: mpz_common::ContextError) -> Self { Self::new(ErrorKind::Vm, error) } } diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 6550463cd..ad0d21e80 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -12,84 +12,101 @@ pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role}; pub use error::PrfError; pub use prf::MpcPrf; -use async_trait::async_trait; - -use mpz_garble::value::ValueRef; +use mpz_vm_core::memory::{binary::U8, Array}; pub(crate) static CF_LABEL: &[u8] = b"client finished"; pub(crate) static SF_LABEL: &[u8] = b"server finished"; +/// Builds the circuits for the PRF. +/// +/// This function can be used ahead of time to build the circuits for the PRF, +/// which at the moment is CPU and memory intensive. +pub async fn build_circuits() { + prf::Circuits::get().await; +} + +/// PRF output. +#[derive(Debug, Clone, Copy)] +pub struct PrfOutput { + /// TLS session keys. + pub keys: SessionKeys, + /// Client finished verify data. + pub cf_vd: Array, + /// Server finished verify data. + pub sf_vd: Array, +} + /// Session keys computed by the PRF. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct SessionKeys { /// Client write key. - pub client_write_key: ValueRef, + pub client_write_key: Array, /// Server write key. - pub server_write_key: ValueRef, + pub server_write_key: Array, /// Client IV. - pub client_iv: ValueRef, + pub client_iv: Array, /// Server IV. - pub server_iv: ValueRef, + pub server_iv: Array, } /// PRF trait for computing TLS PRF. -#[async_trait] -pub trait Prf { +pub trait Prf { /// Sets up the PRF. /// /// # Arguments /// + /// * `vm` - Virtual machine. /// * `pms` - The pre-master secret. - async fn setup(&mut self, pms: ValueRef) -> Result; + fn setup(&mut self, vm: &mut Vm, pms: Array) -> Result; /// Sets the client random. /// - /// This must be set after calling [`Prf::setup`]. - /// /// Only the leader can provide the client random. - async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError>; - - /// Preprocesses the PRF. - async fn preprocess(&mut self) -> Result<(), PrfError>; - - /// Computes the client finished verify data. /// /// # Arguments /// - /// * `handshake_hash` - The handshake transcript hash. - async fn compute_client_finished_vd( + /// * `vm` - Virtual machine. + /// * `client_random` - The client random. + fn set_client_random( &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError>; + vm: &mut Vm, + client_random: Option<[u8; 32]>, + ) -> Result<(), PrfError>; - /// Computes the server finished verify data. + /// Sets the server random. /// /// # Arguments /// + /// * `vm` - Virtual machine. + /// * `server_random` - The server random. + fn set_server_random(&mut self, vm: &mut Vm, server_random: [u8; 32]) -> Result<(), PrfError>; + + /// Sets the client finished handshake hash. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. /// * `handshake_hash` - The handshake transcript hash. - async fn compute_server_finished_vd( - &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError>; + fn set_cf_hash(&mut self, vm: &mut Vm, handshake_hash: [u8; 32]) -> Result<(), PrfError>; - /// Computes the session keys. + /// Sets the server finished handshake hash. /// /// # Arguments /// - /// * `server_random` - The server random. - async fn compute_session_keys( - &mut self, - server_random: [u8; 32], - ) -> Result; + /// * `vm` - Virtual machine. + /// * `handshake_hash` - The handshake transcript hash. + fn set_sf_hash(&mut self, vm: &mut Vm, handshake_hash: [u8; 32]) -> Result<(), PrfError>; } #[cfg(test)] mod tests { use mpz_common::executor::test_st_executor; - use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Decode, Memory}; + use mpz_garble::protocol::semihonest::{Evaluator, Generator}; use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys}; - use mpz_ot::ideal::ot::ideal_ot; + use mpz_ot::ideal::cot::ideal_cot; + use mpz_vm_core::{memory::correlated::Delta, prelude::*}; + use rand::{rngs::StdRng, SeedableRng}; use super::*; @@ -113,120 +130,89 @@ mod tests { #[ignore = "expensive"] #[tokio::test] async fn test_prf() { + let mut rng = StdRng::seed_from_u64(0); + let pms = [42u8; 32]; let client_random = [69u8; 32]; let server_random: [u8; 32] = [96u8; 32]; let ms = compute_ms(pms, client_random, server_random); - let (leader_ctx_0, follower_ctx_0) = test_st_executor(128); - let (leader_ctx_1, follower_ctx_1) = test_st_executor(128); + let (mut leader_ctx, mut follower_ctx) = test_st_executor(128); - let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); - let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); - let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); - let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); + let delta = Delta::random(&mut rng); + let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); - let leader_thread_0 = DEAPThread::new( - DEAPRole::Leader, - [0u8; 32], - leader_ctx_0, - leader_ot_send_0, - leader_ot_recv_0, - ); - let leader_thread_1 = leader_thread_0 - .new_thread(leader_ctx_1, leader_ot_send_1, leader_ot_recv_1) - .unwrap(); + let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); + let mut follower_vm = Evaluator::new(ot_recv); - let follower_thread_0 = DEAPThread::new( - DEAPRole::Follower, - [0u8; 32], - follower_ctx_0, - follower_ot_send_0, - follower_ot_recv_0, - ); - let follower_thread_1 = follower_thread_0 - .new_thread(follower_ctx_1, follower_ot_send_1, follower_ot_recv_1) + let leader_pms: Array = leader_vm.alloc().unwrap(); + leader_vm.mark_public(leader_pms).unwrap(); + leader_vm.assign(leader_pms, pms).unwrap(); + leader_vm.commit(leader_pms).unwrap(); + + let follower_pms: Array = follower_vm.alloc().unwrap(); + follower_vm.mark_public(follower_pms).unwrap(); + follower_vm.assign(follower_pms, pms).unwrap(); + follower_vm.commit(follower_pms).unwrap(); + + let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); + let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); + + let leader_output = leader.setup(&mut leader_vm, leader_pms).unwrap(); + let follower_output = follower.setup(&mut follower_vm, follower_pms).unwrap(); + + leader + .set_client_random(&mut leader_vm, Some(client_random)) .unwrap(); + follower.set_client_random(&mut follower_vm, None).unwrap(); - // Set up public PMS for testing. - let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap(); - let follower_pms = follower_thread_0 - .new_public_input::<[u8; 32]>("pms") + leader + .set_server_random(&mut leader_vm, server_random) + .unwrap(); + follower + .set_server_random(&mut follower_vm, server_random) .unwrap(); - leader_thread_0.assign(&leader_pms, pms).unwrap(); - follower_thread_0.assign(&follower_pms, pms).unwrap(); + let leader_cwk = leader_vm + .decode(leader_output.keys.client_write_key) + .unwrap(); + let leader_swk = leader_vm + .decode(leader_output.keys.server_write_key) + .unwrap(); + let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap(); + let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap(); - let mut leader = MpcPrf::new( - PrfConfig::builder().role(Role::Leader).build().unwrap(), - leader_thread_0, - leader_thread_1, - ); - let mut follower = MpcPrf::new( - PrfConfig::builder().role(Role::Follower).build().unwrap(), - follower_thread_0, - follower_thread_1, - ); + let follower_cwk = follower_vm + .decode(follower_output.keys.client_write_key) + .unwrap(); + let follower_swk = follower_vm + .decode(follower_output.keys.server_write_key) + .unwrap(); + let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); + let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap(); futures::join!( async { - leader.setup(leader_pms).await.unwrap(); - leader.set_client_random(Some(client_random)).await.unwrap(); - leader.preprocess().await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); + leader_vm.execute(&mut leader_ctx).await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); }, async { - follower.setup(follower_pms).await.unwrap(); - follower.set_client_random(None).await.unwrap(); - follower.preprocess().await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); + follower_vm.execute(&mut follower_ctx).await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); } ); - let (leader_session_keys, follower_session_keys) = futures::try_join!( - leader.compute_session_keys(server_random), - follower.compute_session_keys(server_random) - ) - .unwrap(); - - let SessionKeys { - client_write_key: leader_cwk, - server_write_key: leader_swk, - client_iv: leader_civ, - server_iv: leader_siv, - } = leader_session_keys; - - let SessionKeys { - client_write_key: follower_cwk, - server_write_key: follower_swk, - client_iv: follower_civ, - server_iv: follower_siv, - } = follower_session_keys; - - // Decode session keys - let (leader_session_keys, follower_session_keys) = futures::try_join!( - async { - leader - .thread_mut() - .decode(&[leader_cwk, leader_swk, leader_civ, leader_siv]) - .await - }, - async { - follower - .thread_mut() - .decode(&[follower_cwk, follower_swk, follower_civ, follower_siv]) - .await - } - ) - .unwrap(); - - let leader_cwk: [u8; 16] = leader_session_keys[0].clone().try_into().unwrap(); - let leader_swk: [u8; 16] = leader_session_keys[1].clone().try_into().unwrap(); - let leader_civ: [u8; 4] = leader_session_keys[2].clone().try_into().unwrap(); - let leader_siv: [u8; 4] = leader_session_keys[3].clone().try_into().unwrap(); + let leader_cwk = leader_cwk.await.unwrap(); + let leader_swk = leader_swk.await.unwrap(); + let leader_civ = leader_civ.await.unwrap(); + let leader_siv = leader_siv.await.unwrap(); - let follower_cwk: [u8; 16] = follower_session_keys[0].clone().try_into().unwrap(); - let follower_swk: [u8; 16] = follower_session_keys[1].clone().try_into().unwrap(); - let follower_civ: [u8; 4] = follower_session_keys[2].clone().try_into().unwrap(); - let follower_siv: [u8; 4] = follower_session_keys[3].clone().try_into().unwrap(); + let follower_cwk = follower_cwk.await.unwrap(); + let follower_swk = follower_swk.await.unwrap(); + let follower_civ = follower_civ.await.unwrap(); + let follower_siv = follower_siv.await.unwrap(); let (expected_cwk, expected_swk, expected_civ, expected_siv) = session_keys(pms, client_random, server_random); @@ -244,24 +230,43 @@ mod tests { let cf_hs_hash = [1u8; 32]; let sf_hs_hash = [2u8; 32]; - let (cf_vd, _) = futures::try_join!( - leader.compute_client_finished_vd(cf_hs_hash), - follower.compute_client_finished_vd(cf_hs_hash) - ) - .unwrap(); + leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); + leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); - let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash); + follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); + follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); + + let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap(); + let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap(); + + let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap(); + let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap(); + + futures::join!( + async { + leader_vm.flush(&mut leader_ctx).await.unwrap(); + leader_vm.execute(&mut leader_ctx).await.unwrap(); + leader_vm.flush(&mut leader_ctx).await.unwrap(); + }, + async { + follower_vm.flush(&mut follower_ctx).await.unwrap(); + follower_vm.execute(&mut follower_ctx).await.unwrap(); + follower_vm.flush(&mut follower_ctx).await.unwrap(); + } + ); - assert_eq!(cf_vd, expected_cf_vd); + let leader_cf_vd = leader_cf_vd.await.unwrap(); + let leader_sf_vd = leader_sf_vd.await.unwrap(); - let (sf_vd, _) = futures::try_join!( - leader.compute_server_finished_vd(sf_hs_hash), - follower.compute_server_finished_vd(sf_hs_hash) - ) - .unwrap(); + let follower_cf_vd = follower_cf_vd.await.unwrap(); + let follower_sf_vd = follower_sf_vd.await.unwrap(); + let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash); let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash); - assert_eq!(sf_vd, expected_sf_vd); + assert_eq!(leader_cf_vd, expected_cf_vd); + assert_eq!(leader_sf_vd, expected_sf_vd); + assert_eq!(follower_cf_vd, expected_cf_vd); + assert_eq!(follower_sf_vd, expected_sf_vd); } } diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index 77afab3ea..8f38fbec3 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -3,60 +3,65 @@ use std::{ sync::{Arc, OnceLock}, }; -use async_trait::async_trait; - use hmac_sha256_circuits::{build_session_keys, build_verify_data}; use mpz_circuits::Circuit; use mpz_common::cpu::CpuBackend; -use mpz_garble::{config::Visibility, value::ValueRef, Decode, Execute, Load, Memory}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, View, + }, + prelude::*, + Call, Vm as VmTrait, +}; use tracing::instrument; -use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL}; +use crate::{Prf, PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL}; -/// Circuit for computing TLS session keys. -static SESSION_KEYS_CIRC: OnceLock> = OnceLock::new(); -/// Circuit for computing TLS client verify data. -static CLIENT_VD_CIRC: OnceLock> = OnceLock::new(); -/// Circuit for computing TLS server verify data. -static SERVER_VD_CIRC: OnceLock> = OnceLock::new(); - -#[derive(Debug)] -pub(crate) struct Randoms { - pub(crate) client_random: ValueRef, - pub(crate) server_random: ValueRef, +pub(crate) struct Circuits { + session_keys: Arc, + client_vd: Arc, + server_vd: Arc, } -#[derive(Debug, Clone)] -pub(crate) struct HashState { - pub(crate) ms_outer_hash_state: ValueRef, - pub(crate) ms_inner_hash_state: ValueRef, -} +impl Circuits { + pub(crate) async fn get() -> &'static Self { + static CIRCUITS: OnceLock = OnceLock::new(); + if let Some(circuits) = CIRCUITS.get() { + return circuits; + } -#[derive(Debug)] -pub(crate) struct VerifyData { - pub(crate) handshake_hash: ValueRef, - pub(crate) vd: ValueRef, + let (session_keys, client_vd, server_vd) = futures::join!( + CpuBackend::blocking(build_session_keys), + CpuBackend::blocking(|| build_verify_data(CF_LABEL)), + CpuBackend::blocking(|| build_verify_data(SF_LABEL)), + ); + + _ = CIRCUITS.set(Circuits { + session_keys, + client_vd, + server_vd, + }); + + CIRCUITS.get().unwrap() + } } #[derive(Debug)] pub(crate) enum State { Initialized, SessionKeys { - pms: ValueRef, - randoms: Randoms, - hash_state: HashState, - keys: crate::SessionKeys, - cf_vd: VerifyData, - sf_vd: VerifyData, + client_random: Array, + server_random: Array, + cf_hash: Array, + sf_hash: Array, }, ClientFinished { - hash_state: HashState, - cf_vd: VerifyData, - sf_vd: VerifyData, + cf_hash: Array, + sf_hash: Array, }, ServerFinished { - hash_state: HashState, - sf_vd: VerifyData, + sf_hash: Array, }, Complete, Error, @@ -69,14 +74,12 @@ impl State { } /// MPC PRF for computing TLS HMAC-SHA256 PRF. -pub struct MpcPrf { +pub struct MpcPrf { config: PrfConfig, state: State, - thread_0: E, - thread_1: E, } -impl Debug for MpcPrf { +impl Debug for MpcPrf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MpcPrf") .field("config", &self.config) @@ -85,359 +88,178 @@ impl Debug for MpcPrf { } } -impl MpcPrf -where - E: Load + Memory + Execute + Decode + Send, -{ +impl MpcPrf { /// Creates a new instance of the PRF. - pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf { + pub fn new(config: PrfConfig) -> MpcPrf { MpcPrf { config, state: State::Initialized, - thread_0, - thread_1, } } - - /// Returns a mutable reference to the MPC thread. - pub fn thread_mut(&mut self) -> &mut E { - &mut self.thread_0 - } - - /// Executes a circuit which computes TLS session keys. - #[instrument(level = "debug", skip_all, err)] - async fn execute_session_keys( - &mut self, - server_random: [u8; 32], - ) -> Result { - let State::SessionKeys { - pms, - randoms: randoms_refs, - hash_state, - keys, - cf_vd, - sf_vd, - } = self.state.take() - else { - return Err(PrfError::state("session keys not initialized")); - }; - - let circ = SESSION_KEYS_CIRC - .get() - .expect("session keys circuit is set"); - - self.thread_0 - .assign(&randoms_refs.server_random, server_random)?; - - self.thread_0 - .execute( - circ.clone(), - &[pms, randoms_refs.client_random, randoms_refs.server_random], - &[ - keys.client_write_key.clone(), - keys.server_write_key.clone(), - keys.client_iv.clone(), - keys.server_iv.clone(), - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - ], - ) - .await?; - - self.state = State::ClientFinished { - hash_state, - cf_vd, - sf_vd, - }; - - Ok(keys) - } - - #[instrument(level = "debug", skip_all, err)] - async fn execute_cf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> { - let State::ClientFinished { - hash_state, - cf_vd, - sf_vd, - } = self.state.take() - else { - return Err(PrfError::state("PRF not in client finished state")); - }; - - let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set"); - - self.thread_0 - .assign(&cf_vd.handshake_hash, handshake_hash)?; - - self.thread_0 - .execute( - circ.clone(), - &[ - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - cf_vd.handshake_hash, - ], - &[cf_vd.vd.clone()], - ) - .await?; - - let mut outputs = self.thread_0.decode(&[cf_vd.vd]).await?; - let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); - - self.state = State::ServerFinished { hash_state, sf_vd }; - - Ok(vd) - } - - #[instrument(level = "debug", skip_all, err)] - async fn execute_sf_vd(&mut self, handshake_hash: [u8; 32]) -> Result<[u8; 12], PrfError> { - let State::ServerFinished { hash_state, sf_vd } = self.state.take() else { - return Err(PrfError::state("PRF not in server finished state")); - }; - - let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set"); - - self.thread_0 - .assign(&sf_vd.handshake_hash, handshake_hash)?; - - self.thread_0 - .execute( - circ.clone(), - &[ - hash_state.ms_outer_hash_state, - hash_state.ms_inner_hash_state, - sf_vd.handshake_hash, - ], - &[sf_vd.vd.clone()], - ) - .await?; - - let mut outputs = self.thread_0.decode(&[sf_vd.vd]).await?; - let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); - - self.state = State::Complete; - - Ok(vd) - } } -#[async_trait] -impl Prf for MpcPrf +impl Prf for MpcPrf where - E: Memory + Load + Execute + Decode + Send, + Vm: VmTrait + View, { #[instrument(level = "debug", skip_all, err)] - async fn setup(&mut self, pms: ValueRef) -> Result { + fn setup(&mut self, vm: &mut Vm, pms: Array) -> Result { let State::Initialized = self.state.take() else { return Err(PrfError::state("PRF not in initialized state")); }; - let thread = &mut self.thread_0; - - let randoms = Randoms { - // The client random is kept private so that the handshake transcript - // hashes do not leak information about the server's identity. - client_random: thread.new_input::<[u8; 32]>( - "client_random", - match self.config.role { - Role::Leader => Visibility::Private, - Role::Follower => Visibility::Blind, - }, - )?, - server_random: thread.new_input::<[u8; 32]>("server_random", Visibility::Public)?, - }; + let circuits = futures::executor::block_on(Circuits::get()); - let keys = SessionKeys { - client_write_key: thread.new_output::<[u8; 16]>("client_write_key")?, - server_write_key: thread.new_output::<[u8; 16]>("server_write_key")?, - client_iv: thread.new_output::<[u8; 4]>("client_write_iv")?, - server_iv: thread.new_output::<[u8; 4]>("server_write_iv")?, - }; + let client_random = vm.alloc().map_err(PrfError::vm)?; + let server_random = vm.alloc().map_err(PrfError::vm)?; - let hash_state = HashState { - ms_outer_hash_state: thread.new_output::<[u32; 8]>("ms_outer_hash_state")?, - ms_inner_hash_state: thread.new_output::<[u32; 8]>("ms_inner_hash_state")?, - }; + // The client random is kept private so that the handshake transcript + // hashes do not leak information about the server's identity. + match self.config.role { + Role::Leader => vm.mark_private(client_random), + Role::Follower => vm.mark_blind(client_random), + } + .map_err(PrfError::vm)?; + + vm.mark_public(server_random).map_err(PrfError::vm)?; + + #[allow(clippy::type_complexity)] + let ( + client_write_key, + server_write_key, + client_iv, + server_iv, + ms_outer_hash_state, + ms_inner_hash_state, + ): ( + Array, + Array, + Array, + Array, + Array, + Array, + ) = vm + .call( + Call::new(circuits.session_keys.clone()) + .arg(pms) + .arg(client_random) + .arg(server_random) + .build() + .map_err(PrfError::vm)?, + ) + .map_err(PrfError::vm)?; - let cf_vd = VerifyData { - handshake_hash: thread.new_input::<[u8; 32]>("cf_hash", Visibility::Public)?, - vd: thread.new_output::<[u8; 12]>("cf_vd")?, + let keys = SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, }; - let sf_vd = VerifyData { - handshake_hash: thread.new_input::<[u8; 32]>("sf_hash", Visibility::Public)?, - vd: thread.new_output::<[u8; 12]>("sf_vd")?, - }; + let cf_hash = vm.alloc().map_err(PrfError::vm)?; + vm.mark_public(cf_hash).map_err(PrfError::vm)?; + + let cf_vd = vm + .call( + Call::new(circuits.client_vd.clone()) + .arg(ms_outer_hash_state) + .arg(ms_inner_hash_state) + .arg(cf_hash) + .build() + .map_err(PrfError::vm)?, + ) + .map_err(PrfError::vm)?; + + let sf_hash = vm.alloc().map_err(PrfError::vm)?; + vm.mark_public(sf_hash).map_err(PrfError::vm)?; + + let sf_vd = vm + .call( + Call::new(circuits.server_vd.clone()) + .arg(ms_outer_hash_state) + .arg(ms_inner_hash_state) + .arg(sf_hash) + .build() + .map_err(PrfError::vm)?, + ) + .map_err(PrfError::vm)?; self.state = State::SessionKeys { - pms, - randoms, - hash_state, - keys: keys.clone(), - cf_vd, - sf_vd, + client_random, + server_random, + cf_hash, + sf_hash, }; - Ok(keys) + Ok(PrfOutput { keys, cf_vd, sf_vd }) } #[instrument(level = "debug", skip_all, err)] - async fn set_client_random(&mut self, client_random: Option<[u8; 32]>) -> Result<(), PrfError> { - let State::SessionKeys { randoms, .. } = &self.state else { + fn set_client_random(&mut self, vm: &mut Vm, random: Option<[u8; 32]>) -> Result<(), PrfError> { + let State::SessionKeys { client_random, .. } = &self.state else { return Err(PrfError::state("PRF not set up")); }; if self.config.role == Role::Leader { - let Some(client_random) = client_random else { + let Some(random) = random else { return Err(PrfError::role("leader must provide client random")); }; - self.thread_0 - .assign(&randoms.client_random, client_random)?; - } else if client_random.is_some() { + vm.assign(*client_random, random).map_err(PrfError::vm)?; + } else if random.is_some() { return Err(PrfError::role("only leader can set client random")); } - self.thread_0 - .commit(&[randoms.client_random.clone()]) - .await?; + vm.commit(*client_random).map_err(PrfError::vm)?; Ok(()) } #[instrument(level = "debug", skip_all, err)] - async fn preprocess(&mut self) -> Result<(), PrfError> { + fn set_server_random(&mut self, vm: &mut Vm, random: [u8; 32]) -> Result<(), PrfError> { let State::SessionKeys { - pms, - randoms, - hash_state, - keys, - cf_vd, - sf_vd, + server_random, + cf_hash, + sf_hash, + .. } = self.state.take() else { return Err(PrfError::state("PRF not set up")); }; - // Builds all circuits in parallel and preprocesses the session keys circuit. - futures::try_join!( - async { - if SESSION_KEYS_CIRC.get().is_none() { - _ = SESSION_KEYS_CIRC.set(CpuBackend::blocking(build_session_keys).await); - } - - let circ = SESSION_KEYS_CIRC - .get() - .expect("session keys circuit should be built"); - - self.thread_0 - .load( - circ.clone(), - &[ - pms.clone(), - randoms.client_random.clone(), - randoms.server_random.clone(), - ], - &[ - keys.client_write_key.clone(), - keys.server_write_key.clone(), - keys.client_iv.clone(), - keys.server_iv.clone(), - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - ], - ) - .await?; - - Ok::<_, PrfError>(()) - }, - async { - if CLIENT_VD_CIRC.get().is_none() { - _ = CLIENT_VD_CIRC - .set(CpuBackend::blocking(move || build_verify_data(CF_LABEL)).await); - } - - Ok::<_, PrfError>(()) - }, - async { - if SERVER_VD_CIRC.get().is_none() { - _ = SERVER_VD_CIRC - .set(CpuBackend::blocking(move || build_verify_data(SF_LABEL)).await); - } - - Ok::<_, PrfError>(()) - } - )?; - - // Finishes preprocessing the verify data circuits. - futures::try_join!( - async { - self.thread_0 - .load( - CLIENT_VD_CIRC - .get() - .expect("client finished circuit should be built") - .clone(), - &[ - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - cf_vd.handshake_hash.clone(), - ], - &[cf_vd.vd.clone()], - ) - .await - }, - async { - self.thread_1 - .load( - SERVER_VD_CIRC - .get() - .expect("server finished circuit should be built") - .clone(), - &[ - hash_state.ms_outer_hash_state.clone(), - hash_state.ms_inner_hash_state.clone(), - sf_vd.handshake_hash.clone(), - ], - &[sf_vd.vd.clone()], - ) - .await - } - )?; + vm.assign(server_random, random).map_err(PrfError::vm)?; + vm.commit(server_random).map_err(PrfError::vm)?; - self.state = State::SessionKeys { - pms, - randoms, - hash_state, - keys, - cf_vd, - sf_vd, - }; + self.state = State::ClientFinished { cf_hash, sf_hash }; Ok(()) } #[instrument(level = "debug", skip_all, err)] - async fn compute_client_finished_vd( - &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError> { - self.execute_cf_vd(handshake_hash).await - } + fn set_cf_hash(&mut self, vm: &mut Vm, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else { + return Err(PrfError::state("PRF not in client finished state")); + }; - #[instrument(level = "debug", skip_all, err)] - async fn compute_server_finished_vd( - &mut self, - handshake_hash: [u8; 32], - ) -> Result<[u8; 12], PrfError> { - self.execute_sf_vd(handshake_hash).await + vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?; + vm.commit(cf_hash).map_err(PrfError::vm)?; + + self.state = State::ServerFinished { sf_hash }; + + Ok(()) } #[instrument(level = "debug", skip_all, err)] - async fn compute_session_keys( - &mut self, - server_random: [u8; 32], - ) -> Result { - self.execute_session_keys(server_random).await + fn set_sf_hash(&mut self, vm: &mut Vm, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ServerFinished { sf_hash } = self.state.take() else { + return Err(PrfError::state("PRF not in server finished state")); + }; + + vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?; + vm.commit(sf_hash).map_err(PrfError::vm)?; + + self.state = State::Complete; + + Ok(()) } } From 9f7cef22562dc9baf63bbbcb22982030cbbf32ef Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 15 Jan 2025 13:02:09 +0100 Subject: [PATCH 2/4] refactor: remove preprocessing bench --- crates/components/hmac-sha256/benches/prf.rs | 69 -------------------- 1 file changed, 69 deletions(-) diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 6b826c4bf..597b442a8 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -16,81 +16,12 @@ fn criterion_benchmark(c: &mut Criterion) { group.sample_size(10); let rt = tokio::runtime::Runtime::new().unwrap(); - //group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess)); group.bench_function("prf", |b| b.to_async(&rt).iter(prf)); } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); -// async fn preprocess() { -// let (mut leader_exec, mut follower_exec) = test_mt_executor(128); - -// let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot(); -// let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot(); -// let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot(); -// let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot(); - -// let leader_thread_0 = DEAPThread::new( -// DEAPRole::Leader, -// [0u8; 32], -// leader_exec.new_thread().await.unwrap(), -// leader_ot_send_0, -// leader_ot_recv_0, -// ); -// let leader_thread_1 = leader_thread_0 -// .new_thread( -// leader_exec.new_thread().await.unwrap(), -// leader_ot_send_1, -// leader_ot_recv_1, -// ) -// .unwrap(); - -// let follower_thread_0 = DEAPThread::new( -// DEAPRole::Follower, -// [0u8; 32], -// follower_exec.new_thread().await.unwrap(), -// follower_ot_send_0, -// follower_ot_recv_0, -// ); -// let follower_thread_1 = follower_thread_0 -// .new_thread( -// follower_exec.new_thread().await.unwrap(), -// follower_ot_send_1, -// follower_ot_recv_1, -// ) -// .unwrap(); - -// let leader_pms = leader_thread_0.new_public_input::<[u8; -// 32]>("pms").unwrap(); let follower_pms = follower_thread_0 -// .new_public_input::<[u8; 32]>("pms") -// .unwrap(); - -// let mut leader = MpcPrf::new( -// PrfConfig::builder().role(Role::Leader).build().unwrap(), -// leader_thread_0, -// leader_thread_1, -// ); -// let mut follower = MpcPrf::new( -// PrfConfig::builder().role(Role::Follower).build().unwrap(), -// follower_thread_0, -// follower_thread_1, -// ); - -// futures::join!( -// async { -// leader.setup(leader_pms).await.unwrap(); -// leader.set_client_random(Some([0u8; 32])).await.unwrap(); -// leader.preprocess().await.unwrap(); -// }, -// async { -// follower.setup(follower_pms).await.unwrap(); -// follower.set_client_random(None).await.unwrap(); -// follower.preprocess().await.unwrap(); -// } -// ); -// } - async fn prf() { let mut rng = StdRng::seed_from_u64(0); From 223e23b2f085a011010ac0ca6899c4990a3885b0 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 16 Jan 2025 10:34:59 +0100 Subject: [PATCH 3/4] fix: fix feature flags --- crates/components/hmac-sha256/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index d19ce0495..dbaf80581 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -21,7 +21,7 @@ tlsn-hmac-sha256-circuits = { workspace = true } mpz-vm-core = { workspace = true } mpz-circuits = { workspace = true } -mpz-common = { workspace = true } +mpz-common = { workspace = true, features = ["cpu"] } derive_builder = { workspace = true } thiserror = { workspace = true } From ce5cbba5040b9c7f642b4be892dc3af5d2935d08 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:06:23 -0800 Subject: [PATCH 4/4] clean up attributes --- crates/components/hmac-sha256/benches/prf.rs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 597b442a8..af832fb5c 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -1,3 +1,5 @@ +#![allow(clippy::let_underscore_future)] + use criterion::{criterion_group, criterion_main, Criterion}; use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role}; @@ -67,30 +69,22 @@ async fn prf() { .set_server_random(&mut follower_vm, server_random) .unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm .decode(leader_output.keys.client_write_key) .unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm .decode(leader_output.keys.server_write_key) .unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm.decode(leader_output.keys.client_iv).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm.decode(leader_output.keys.server_iv).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm .decode(follower_output.keys.client_write_key) .unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm .decode(follower_output.keys.server_write_key) .unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap(); futures::join!( @@ -115,14 +109,10 @@ async fn prf() { follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); - #[allow(clippy::let_underscore_future)] let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); futures::join!(