From ce323b6a70bc9ace4d458c94ce62d1f086d71d34 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:39:58 -0800 Subject: [PATCH] refactor: key exchange interface --- crates/components/key-exchange/Cargo.toml | 2 +- crates/components/key-exchange/src/error.rs | 7 +- .../components/key-exchange/src/exchange.rs | 534 ++++++++---------- crates/components/key-exchange/src/lib.rs | 43 +- crates/components/key-exchange/src/mock.rs | 5 +- .../key-exchange/src/point_addition.rs | 25 +- 6 files changed, 275 insertions(+), 341 deletions(-) diff --git a/crates/components/key-exchange/Cargo.toml b/crates/components/key-exchange/Cargo.toml index 7f138fc22..3bcdbe462 100644 --- a/crates/components/key-exchange/Cargo.toml +++ b/crates/components/key-exchange/Cargo.toml @@ -37,6 +37,6 @@ tokio = { workspace = true, features = ["sync"] } mpz-ot = { workspace = true, features = ["ideal"] } mpz-garble = { workspace = true } -rand_chacha = { workspace = true } rand_core = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } +rstest = { workspace = true } diff --git a/crates/components/key-exchange/src/error.rs b/crates/components/key-exchange/src/error.rs index 24e64dbc3..241908e7d 100644 --- a/crates/components/key-exchange/src/error.rs +++ b/crates/components/key-exchange/src/error.rs @@ -3,7 +3,7 @@ use std::error::Error; /// MPC-TLS protocol error. #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct KeyExchangeError(#[from] ErrorRepr); +pub struct KeyExchangeError(#[from] pub(crate) ErrorRepr); #[derive(Debug, thiserror::Error)] #[error("key exchange error: {0}")] @@ -66,11 +66,6 @@ impl KeyExchangeError { { Self(ErrorRepr::Key(err.into())) } - - #[cfg(test)] - pub(crate) fn kind(&self) -> &ErrorRepr { - &self.0 - } } impl From for KeyExchangeError { diff --git a/crates/components/key-exchange/src/exchange.rs b/crates/components/key-exchange/src/exchange.rs index 706f03423..cd1ca4115 100644 --- a/crates/components/key-exchange/src/exchange.rs +++ b/crates/components/key-exchange/src/exchange.rs @@ -13,10 +13,10 @@ use mpz_core::bitvec::BitVec; use mpz_fields::{p256::P256, Field}; use mpz_memory_core::{ binary::{Binary, U8}, - Array, DecodeFutureTyped, Memory, MemoryExt, View, ViewExt, + Array, DecodeFutureTyped, MemoryExt, ViewExt, }; use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert}; -use mpz_vm_core::{CallBuilder, Vm, VmExt}; +use mpz_vm_core::{CallBuilder, CallableExt, Vm}; use crate::{ circuit::build_pms_circuit, @@ -107,12 +107,18 @@ impl MpcKeyExchange { } } +#[async_trait] impl KeyExchange for MpcKeyExchange where - C0: ShareConvert + Send, - C1: ShareConvert + Send, + C0: ShareConvert + Flush + Send + 'static, + C1: ShareConvert + Flush + Send + 'static, { - fn alloc(&mut self) -> Result<(), KeyExchangeError> { + #[instrument(level = "debug", skip_all, err)] + fn alloc(&mut self, vm: &mut dyn Vm) -> Result { + let State::Initialized = self.state.take() else { + return Err(KeyExchangeError::state("should be in Initialized state")); + }; + let mut converter_0 = self.converter_0.try_lock().unwrap(); let mut converter_1 = self.converter_1.try_lock().unwrap(); @@ -127,9 +133,63 @@ where AdditiveToMultiplicative::alloc(&mut *converter_1, 2) .map_err(KeyExchangeError::share_conversion)?; - Ok(()) + let (share_a0, share_b0, share_a1, share_b1) = match self.config.role() { + Role::Leader => { + let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_a0).map_err(KeyExchangeError::vm)?; + + let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_b0).map_err(KeyExchangeError::vm)?; + + let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_a1).map_err(KeyExchangeError::vm)?; + + let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_b1).map_err(KeyExchangeError::vm)?; + + (share_a0, share_b0, share_a1, share_b1) + } + Role::Follower => { + let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_a0).map_err(KeyExchangeError::vm)?; + + let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_b0).map_err(KeyExchangeError::vm)?; + + let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_a1).map_err(KeyExchangeError::vm)?; + + let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_b1).map_err(KeyExchangeError::vm)?; + + (share_a0, share_b0, share_a1, share_b1) + } + }; + + let pms_circuit = build_pms_circuit(); + let pms_call = CallBuilder::new(pms_circuit) + .arg(share_a0) + .arg(share_b0) + .arg(share_a1) + .arg(share_b1) + .build() + .map_err(KeyExchangeError::vm)?; + + let (pms, _, eq): (Array, Array, Array) = + vm.call(pms_call).map_err(KeyExchangeError::vm)?; + + self.state = State::Setup { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + }; + + Ok(pms) } + #[instrument(level = "debug", skip_all, err)] fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError> { let Role::Leader = self.config.role() else { return Err(KeyExchangeError::role("follower cannot set server key")); @@ -167,77 +227,93 @@ where } #[instrument(level = "debug", skip_all, err)] - fn setup(&mut self, vm: &mut V) -> Result - where - V: Vm + Memory + View, - { - let State::Initialized = self.state.take() else { - return Err(KeyExchangeError::state( - "should be in Initialized state to call setup", - )); + async fn setup(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError> { + let State::Setup { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + } = self.state.take() + else { + return Err(KeyExchangeError::state("should be in setup state")); }; - let (share_a0, share_b0, share_a1, share_b1) = match self.config.role() { - Role::Leader => { - let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_private(share_a0).map_err(KeyExchangeError::vm)?; - - let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_blind(share_b0).map_err(KeyExchangeError::vm)?; - - let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_private(share_a1).map_err(KeyExchangeError::vm)?; + let follower_key = match self.config.role() { + Role::Leader => ctx.io_mut().expect_next().await?, + Role::Follower => { + let follower_key = self.private_key.public_key(); + ctx.io_mut().send(follower_key).await?; + follower_key + } + }; - let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_blind(share_b1).map_err(KeyExchangeError::vm)?; + self.state = State::FollowerKey { + follower_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + }; - (share_a0, share_b0, share_a1, share_b1) - } - Role::Follower => { - let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_blind(share_a0).map_err(KeyExchangeError::vm)?; + Ok(()) + } - let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_private(share_b0).map_err(KeyExchangeError::vm)?; + #[instrument(level = "debug", skip_all, err)] + async fn compute_shares(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError> { + let State::FollowerKey { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + .. + } = self.state.take() + else { + return Err(KeyExchangeError::state( + "can not compute shares before performing setup", + )); + }; - let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_blind(share_a1).map_err(KeyExchangeError::vm)?; + let server_key = match self.config.role() { + Role::Leader => { + let server_key = self + .server_key + .ok_or_else(|| KeyExchangeError::role("server key is not set"))?; - let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; - vm.mark_private(share_b1).map_err(KeyExchangeError::vm)?; + ctx.io_mut().send(server_key).await?; - (share_a0, share_b0, share_a1, share_b1) + server_key } + Role::Follower => ctx.io_mut().expect_next().await?, }; - let pms_circuit = build_pms_circuit(); - let pms_call = CallBuilder::new(pms_circuit) - .arg(share_a0) - .arg(share_b0) - .arg(share_a1) - .arg(share_b1) - .build() - .map_err(KeyExchangeError::vm)?; - - let (pms, _, eq): (Array, Array, Array) = - vm.call(pms_call).map_err(KeyExchangeError::vm)?; + let (pms_0, pms_1) = compute_ec_shares( + ctx, + self.config.role(), + self.converter_0.clone(), + self.converter_1.clone(), + self.private_key.clone(), + server_key, + ) + .await?; - self.state = State::Setup { + self.state = State::ComputedECShares { share_a0, share_b0, share_a1, share_b1, eq, + pms_0, + pms_1, }; - Ok(Pms::new(pms)) + Ok(()) } - #[instrument(level = "debug", skip_all, err)] - fn compute_pms(&mut self, vm: &mut V) -> Result<(), KeyExchangeError> - where - V: Vm + Memory + View, - { + #[instrument(level = "trace", skip_all, err)] + fn assign(&mut self, vm: &mut dyn Vm) -> Result<(), KeyExchangeError> { let State::ComputedECShares { share_a0, share_b0, @@ -296,118 +372,29 @@ where Ok(()) } -} -#[async_trait] -impl Flush for MpcKeyExchange -where - Ctx: Context, - C0: ShareConvert + Flush + Send + 'static, - >::Future: Send, - >::Future: Send, - C1: ShareConvert + Flush + Send + 'static, - >::Future: Send, - >::Future: Send, -{ - type Error = KeyExchangeError; + #[instrument(level = "debug", skip_all, err)] + async fn finalize(&mut self) -> Result<(), KeyExchangeError> { + let State::EqualityCheck { eq } = self.state.take() else { + return Err(KeyExchangeError::state( + "can not finalize before PMS is computed", + )); + }; - fn wants_flush(&self) -> bool { - !matches!( - self.state, - State::Initialized | State::Complete | State::Error - ) - } + let eq = eq.await.map_err(KeyExchangeError::vm)?; - async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { - if !self.wants_flush() { - return Ok(()); + if eq != [0u8; 32] { + return Err(KeyExchangeError::share_conversion("PMS values not equal")); } - self.state = match self.state.take() { - State::Setup { - share_a0, - share_b0, - share_a1, - share_b1, - eq, - } => { - let follower_key = match self.config.role() { - Role::Leader => ctx.io_mut().expect_next().await?, - Role::Follower => { - let follower_key = self.private_key.public_key(); - ctx.io_mut().send(follower_key).await?; - follower_key - } - }; - - State::FollowerKey { - follower_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - } - } - State::FollowerKey { - share_a0, - share_b0, - share_a1, - share_b1, - eq, - .. - } => { - let server_key = match self.config.role() { - Role::Leader => { - let server_key = self - .server_key - .ok_or_else(|| KeyExchangeError::role("server key is not set"))?; - - ctx.io_mut().send(server_key).await?; - - server_key - } - Role::Follower => ctx.io_mut().expect_next().await?, - }; - - let (pms_0, pms_1) = compute_ec_shares( - ctx, - self.config.role(), - self.converter_0.clone(), - self.converter_1.clone(), - self.private_key.clone(), - server_key, - ) - .await?; - - State::ComputedECShares { - share_a0, - share_b0, - share_a1, - share_b1, - eq, - pms_0, - pms_1, - } - } - State::EqualityCheck { eq } => { - let eq = eq.await.map_err(KeyExchangeError::vm)?; - - if eq != [0u8; 32] { - return Err(KeyExchangeError::share_conversion("PMS values not equal")); - } - - State::Complete - } - state => state, - }; + self.state = State::Complete; Ok(()) } } -async fn compute_ec_shares( - ctx: &mut Ctx, +async fn compute_ec_shares( + ctx: &mut Context, role: Role, converter_0: Arc>, converter_1: Arc>, @@ -415,13 +402,8 @@ async fn compute_ec_shares( server_key: PublicKey, ) -> Result<(P256, P256), KeyExchangeError> where - Ctx: Context, - C0: ShareConvert + Flush + Send + 'static, - >::Future: Send, - >::Future: Send, - C1: ShareConvert + Flush + Send + 'static, - >::Future: Send, - >::Future: Send, + C0: ShareConvert + Flush + Send + 'static, + C1: ShareConvert + Flush + Send + 'static, { // Compute the leader's/follower's share of the pre-master secret. // @@ -457,7 +439,7 @@ where mod tests { use super::*; use crate::error::ErrorRepr; - use mpz_common::executor::test_st_executor; + use mpz_common::context::test_st_context; use mpz_core::Block; use mpz_garble::protocol::semihonest::{Evaluator, Generator}; use mpz_memory_core::correlated::Delta; @@ -467,9 +449,9 @@ mod tests { }; use mpz_vm_core::Execute; use p256::{NonZeroScalar, PublicKey, SecretKey}; - use rand::rngs::StdRng; - use rand_chacha::ChaCha12Rng; + use rand::{rngs::StdRng, Rng}; use rand_core::SeedableRng; + use rstest::*; impl MpcKeyExchange { fn set_pms_0(&mut self, pms: P256) { @@ -482,120 +464,73 @@ mod tests { #[tokio::test] async fn test_key_exchange() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let mut rng = StdRng::seed_from_u64(0); + let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut gen, mut ev) = mock_vm(); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng)); + let expected_client_public_key = PublicKey::from_affine( + (leader_private_key.public_key().to_projective() + + follower_private_key.public_key().to_projective()) + .to_affine(), + ) + .unwrap(); let (mut leader, mut follower) = create_pair(); - leader.alloc().unwrap(); - follower.alloc().unwrap(); + let leader_pms = leader.alloc(&mut gen).unwrap(); + let follower_pms = follower.alloc(&mut ev).unwrap(); + + let mut leader_pms = gen.decode(leader_pms).unwrap(); + let mut follower_pms = ev.decode(follower_pms).unwrap(); leader.private_key = leader_private_key.clone(); follower.private_key = follower_private_key.clone(); - leader.setup(&mut gen).unwrap(); - follower.setup(&mut ev).unwrap(); - - tokio::join!( + let (leader_pms, follower_pms) = tokio::join!( async { - leader.flush(&mut ctx_a).await.unwrap(); + leader.setup(&mut ctx_a).await.unwrap(); let client_public_key = leader.client_key().unwrap(); - leader.set_server_key(server_public_key).unwrap(); - - assert_eq!(leader.server_key().unwrap(), server_public_key); - - let expected_client_public_key = PublicKey::from_affine( - (leader_private_key.public_key().to_projective() - + follower_private_key.public_key().to_projective()) - .to_affine(), - ) - .unwrap(); - assert_eq!(client_public_key, expected_client_public_key); - leader.flush(&mut ctx_a).await.unwrap(); - }, - async { - follower.flush(&mut ctx_b).await.unwrap(); - follower.flush(&mut ctx_b).await.unwrap(); - } - ); - } - - #[tokio::test] - async fn test_compute_pms() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - let (mut ctx_a, mut ctx_b) = test_st_executor(8); - let (mut gen, mut ev) = mock_vm(); - - let leader_private_key = SecretKey::random(&mut rng); - let follower_private_key = SecretKey::random(&mut rng); - let server_private_key = NonZeroScalar::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&server_private_key); - let (mut leader, mut follower) = create_pair(); - - leader.alloc().unwrap(); - follower.alloc().unwrap(); - - leader.private_key = leader_private_key.clone(); - follower.private_key = follower_private_key.clone(); - - let leader_pms = leader.setup(&mut gen).unwrap().into_value(); - let leader_pms = gen.decode(leader_pms).unwrap(); - - let follower_pms = follower.setup(&mut ev).unwrap().into_value(); - let follower_pms = ev.decode(follower_pms).unwrap(); - - tokio::join!( - async { - leader.flush(&mut ctx_a).await.unwrap(); - let _client_public_key = leader.client_key().unwrap(); leader.set_server_key(server_public_key).unwrap(); - leader.flush(&mut ctx_a).await.unwrap(); - }, - async { - follower.flush(&mut ctx_b).await.unwrap(); - follower.flush(&mut ctx_b).await.unwrap(); - } - ); - - leader.compute_pms(&mut gen).unwrap(); - follower.compute_pms(&mut ev).unwrap(); + leader.compute_shares(&mut ctx_a).await.unwrap(); + leader.assign(&mut gen).unwrap(); - tokio::join!( - async { gen.flush(&mut ctx_a).await.unwrap(); gen.execute(&mut ctx_a).await.unwrap(); - gen.flush(&mut ctx_a) - .await - .map_err(KeyExchangeError::vm) - .unwrap(); + gen.flush(&mut ctx_a).await.unwrap(); + + leader.finalize().await.unwrap(); + + leader_pms.try_recv().unwrap().unwrap() }, async { + follower.setup(&mut ctx_b).await.unwrap(); + follower.compute_shares(&mut ctx_b).await.unwrap(); + follower.assign(&mut ev).unwrap(); + ev.flush(&mut ctx_b).await.unwrap(); ev.execute(&mut ctx_b).await.unwrap(); - ev.flush(&mut ctx_b) - .await - .map_err(KeyExchangeError::vm) - .unwrap(); + ev.flush(&mut ctx_b).await.unwrap(); + + follower.finalize().await.unwrap(); + + follower_pms.try_recv().unwrap().unwrap() } ); - let (leader_pms, follower_pms) = tokio::try_join!(leader_pms, follower_pms).unwrap(); assert_eq!(leader_pms, follower_pms); } #[tokio::test] async fn test_compute_ec_shares() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - let (mut ctx_leader, mut ctx_follower) = test_st_executor(8); + let mut rng = StdRng::seed_from_u64(0); + let (mut ctx_leader, mut ctx_follower) = test_st_context(8); let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO); let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO); @@ -653,79 +588,98 @@ mod tests { assert_ne!(leader_share_1, follower_share_1); } + enum Malicious { + Leader, + Follower, + } + + #[rstest] + #[case::malicious_leader(Malicious::Leader)] + #[case::malicious_follower(Malicious::Follower)] #[tokio::test] - async fn test_compute_pms_fail() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); - let (mut ctx_a, mut ctx_b) = test_st_executor(8); + async fn test_malicious_key_exchange(#[case] malicious: Malicious) { + let mut rng = StdRng::seed_from_u64(0); + let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut gen, mut ev) = mock_vm(); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); - let server_private_key = NonZeroScalar::random(&mut rng); - let server_public_key = PublicKey::from_secret_scalar(&server_private_key); + let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng)); + let expected_client_public_key = PublicKey::from_affine( + (leader_private_key.public_key().to_projective() + + follower_private_key.public_key().to_projective()) + .to_affine(), + ) + .unwrap(); let (mut leader, mut follower) = create_pair(); - leader.alloc().unwrap(); - follower.alloc().unwrap(); + leader.alloc(&mut gen).unwrap(); + follower.alloc(&mut ev).unwrap(); leader.private_key = leader_private_key.clone(); follower.private_key = follower_private_key.clone(); - leader.setup(&mut gen).unwrap(); - follower.setup(&mut ev).unwrap(); + let bad_pms_share: P256 = rng.gen(); - tokio::join!( + let (leader_err, follower_err) = tokio::join!( async { - leader.flush(&mut ctx_a).await.unwrap(); - let _client_public_key = leader.client_key().unwrap(); + leader.setup(&mut ctx_a).await.unwrap(); + + let client_public_key = leader.client_key().unwrap(); + + assert_eq!(client_public_key, expected_client_public_key); + leader.set_server_key(server_public_key).unwrap(); - leader.flush(&mut ctx_a).await.unwrap(); - }, - async { - follower.flush(&mut ctx_b).await.unwrap(); - follower.flush(&mut ctx_b).await.unwrap(); - } - ); + leader.compute_shares(&mut ctx_a).await.unwrap(); - // Now manipulate pms - leader.set_pms_0(P256::one()); - follower.set_pms_0(P256::one()); + // Replace the leader's share with a different value. + if let Malicious::Leader = malicious { + leader.set_pms_0(bad_pms_share.clone()); + } - leader.compute_pms(&mut gen).unwrap(); - follower.compute_pms(&mut ev).unwrap(); + leader.assign(&mut gen).unwrap(); - let (leader_res, follower_res) = tokio::join!( - async { gen.flush(&mut ctx_a).await.unwrap(); gen.execute(&mut ctx_a).await.unwrap(); - gen.flush(&mut ctx_a) - .await - .map_err(KeyExchangeError::vm) - .unwrap(); - leader.flush(&mut ctx_a).await + gen.flush(&mut ctx_a).await.unwrap(); + + leader.finalize().await }, async { + follower.setup(&mut ctx_b).await.unwrap(); + follower.compute_shares(&mut ctx_b).await.unwrap(); + + // Replace the follower's share with a different value. + if let Malicious::Follower = malicious { + follower.set_pms_0(bad_pms_share.clone()); + } + + follower.assign(&mut ev).unwrap(); + ev.flush(&mut ctx_b).await.unwrap(); ev.execute(&mut ctx_b).await.unwrap(); - ev.flush(&mut ctx_b) - .await - .map_err(KeyExchangeError::vm) - .unwrap(); - follower.flush(&mut ctx_b).await + ev.flush(&mut ctx_b).await.unwrap(); + + follower.finalize().await } ); - let leader_err = leader_res.unwrap_err(); - let follower_err = follower_res.unwrap_err(); - - assert!(matches!(leader_err.kind(), ErrorRepr::ShareConversion(_))); - assert!(matches!(follower_err.kind(), ErrorRepr::ShareConversion(_))); + match malicious { + Malicious::Leader => assert!(matches!( + follower_err.unwrap_err().0, + ErrorRepr::ShareConversion(_) + )), + Malicious::Follower => assert!(matches!( + leader_err.unwrap_err().0, + ErrorRepr::ShareConversion(_) + )), + } } #[tokio::test] async fn test_circuit() { - let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let (mut ctx_a, mut ctx_b) = test_st_context(8); let (gen, ev) = mock_vm(); let share_a0_bytes = [5_u8; 32]; diff --git a/crates/components/key-exchange/src/lib.rs b/crates/components/key-exchange/src/lib.rs index e1c824db1..4df2df141 100644 --- a/crates/components/key-exchange/src/lib.rs +++ b/crates/components/key-exchange/src/lib.rs @@ -22,39 +22,29 @@ mod exchange; pub mod mock; pub(crate) mod point_addition; +use async_trait::async_trait; pub use config::{ KeyExchangeConfig, KeyExchangeConfigBuilder, KeyExchangeConfigBuilderError, Role, }; pub use error::KeyExchangeError; pub use exchange::MpcKeyExchange; +use mpz_common::Context; use mpz_memory_core::{ binary::{Binary, U8}, - Array, Memory, View, + Array, }; use mpz_vm_core::Vm; use p256::PublicKey; /// Pre-master secret. -#[derive(Debug, Clone, Copy)] -pub struct Pms(Array); - -impl Pms { - /// Creates a new PMS. - pub fn new(pms: Array) -> Self { - Self(pms) - } - - /// Gets the value of the PMS. - pub fn into_value(self) -> Array { - self.0 - } -} +pub type Pms = Array; /// A trait for the 3-party key exchange protocol. +#[async_trait] pub trait KeyExchange { /// Allocate necessary computational resources. - fn alloc(&mut self) -> Result<(), KeyExchangeError>; + fn alloc(&mut self, vm: &mut dyn Vm) -> Result; /// Sets the server's public key. fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>; @@ -69,16 +59,15 @@ pub trait KeyExchange { /// key. fn client_key(&self) -> Result; - /// Performs any necessary one-time setup, returning a reference to the PMS. - fn setup(&mut self, vm: &mut V) -> Result - where - V: Vm + Memory + View; + /// Performs one-time setup for the key exchange protocol. + async fn setup(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>; - /// Computes the PMS, and returns an equality check. - /// - /// The equality check makes sure that both parties arrived at the same - /// result. This MUST be called to prevent malicious behavior! - fn compute_pms(&mut self, vm: &mut V) -> Result<(), KeyExchangeError> - where - V: Vm + Memory + View; + /// Computes the shares of the PMS. + async fn compute_shares(&mut self, ctx: &mut Context) -> Result<(), KeyExchangeError>; + + /// Assigns the PMS shares to the VM. + fn assign(&mut self, vm: &mut dyn Vm) -> Result<(), KeyExchangeError>; + + /// Finalizes the key exchange protocol. + async fn finalize(&mut self) -> Result<(), KeyExchangeError>; } diff --git a/crates/components/key-exchange/src/mock.rs b/crates/components/key-exchange/src/mock.rs index f525f7056..02cb95647 100644 --- a/crates/components/key-exchange/src/mock.rs +++ b/crates/components/key-exchange/src/mock.rs @@ -44,7 +44,6 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) { #[cfg(test)] mod tests { - use mpz_common::executor::TestSTExecutor; use mpz_garble::protocol::semihonest::{Evaluator, Generator}; use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender}; @@ -55,17 +54,15 @@ mod tests { fn test_mock_is_ke() { let (leader, follower) = create_mock_key_exchange_pair(); - fn is_key_exchange(_: T) {} + fn is_key_exchange(_: T) {} is_key_exchange::< MpcKeyExchange, IdealShareConvertReceiver>, - TestSTExecutor, Generator, >(leader); is_key_exchange::< MpcKeyExchange, IdealShareConvertReceiver>, - TestSTExecutor, Evaluator, >(follower); } diff --git a/crates/components/key-exchange/src/point_addition.rs b/crates/components/key-exchange/src/point_addition.rs index b3670ffbb..f24cca24e 100644 --- a/crates/components/key-exchange/src/point_addition.rs +++ b/crates/components/key-exchange/src/point_addition.rs @@ -1,6 +1,7 @@ -//! This module implements a secure two-party computation protocol for adding two private EC points -//! and secret-sharing the resulting x coordinate (the shares are field elements of the field -//! underlying the elliptic curve). This protocol has semi-honest security. +//! This module implements a secure two-party computation protocol for adding +//! two private EC points and secret-sharing the resulting x coordinate (the +//! shares are field elements of the field underlying the elliptic curve). This +//! protocol has semi-honest security. //! //! The protocol is described in //! @@ -12,15 +13,14 @@ use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, S use p256::EncodedPoint; /// Derives the x-coordinate share of an elliptic curve point. -pub(crate) async fn derive_x_coord_share( - ctx: &mut Ctx, +pub(crate) async fn derive_x_coord_share( + ctx: &mut Context, role: Role, converter: &mut C, share: EncodedPoint, ) -> Result where - Ctx: Context, - C: ShareConvert + Flush + Send, + C: ShareConvert + Flush + Send, >::Future: Send, >::Future: Send, { @@ -98,7 +98,7 @@ fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> { mod tests { use super::*; - use mpz_common::executor::test_st_executor; + use mpz_common::context::test_st_context; use mpz_core::Block; use mpz_fields::{p256::P256, Field}; use mpz_share_conversion::ideal::ideal_share_convert; @@ -106,13 +106,12 @@ mod tests { elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}, EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey, }; - use rand::{Rng, SeedableRng}; - use rand_chacha::ChaCha12Rng; + use rand::{rngs::StdRng, Rng, SeedableRng}; #[tokio::test] async fn test_point_addition() { - let (mut ctx_a, mut ctx_b) = test_st_executor(8); - let mut rng = ChaCha12Rng::from_seed([0u8; 32]); + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let mut rng = StdRng::seed_from_u64(0); let p1: [u8; 32] = rng.gen(); let p2: [u8; 32] = rng.gen(); @@ -137,7 +136,7 @@ mod tests { #[test] fn test_decompose_point() { - let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + let mut rng = StdRng::seed_from_u64(0); let p_expected: [u8; 32] = rng.gen(); let p_expected = curve_point_from_be_bytes(p_expected);