From 87595005b70064d8b4a83acae508cf6d1993da82 Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 15 Jan 2025 13:23:44 +0100 Subject: [PATCH 1/5] refactor(key-exchange): adapt key-exchange to new vm --- crates/components/key-exchange/Cargo.toml | 21 +- crates/components/key-exchange/src/circuit.rs | 3 +- crates/components/key-exchange/src/config.rs | 4 +- crates/components/key-exchange/src/error.rs | 164 ++- .../components/key-exchange/src/exchange.rs | 1080 +++++++++++------ crates/components/key-exchange/src/lib.rs | 62 +- crates/components/key-exchange/src/mock.rs | 55 +- .../key-exchange/src/point_addition.rs | 76 +- 8 files changed, 900 insertions(+), 565 deletions(-) diff --git a/crates/components/key-exchange/Cargo.toml b/crates/components/key-exchange/Cargo.toml index 1bc93237d..21cd9679a 100644 --- a/crates/components/key-exchange/Cargo.toml +++ b/crates/components/key-exchange/Cargo.toml @@ -16,28 +16,25 @@ default = ["mock"] mock = [] [dependencies] -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ - "ideal", -] } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac" } +mpz-vm-core = { workspace = true } +mpz-memory-core = { workspace = true } +mpz-common = { workspace = true } +mpz-fields = { workspace = true } +mpz-share-conversion = { workspace = true, features = ["test-utils"] } +mpz-circuits = { workspace = true } +mpz-core = { workspace = true } p256 = { workspace = true, features = ["ecdh", "serde"] } async-trait = { workspace = true } thiserror = { workspace = true } -serde = { workspace = true } -futures = { workspace = true } serio = { workspace = true } derive_builder = { workspace = true } tracing = { workspace = true } rand = { workspace = true } [dev-dependencies] -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "b8ae7ac", features = [ - "ideal", -] } +mpz-ot = { workspace = true , features = ["ideal"] } +mpz-garble = { workspace = true } rand_chacha = { workspace = true } rand_core = { workspace = true } diff --git a/crates/components/key-exchange/src/circuit.rs b/crates/components/key-exchange/src/circuit.rs index 28fdc4271..8036c491b 100644 --- a/crates/components/key-exchange/src/circuit.rs +++ b/crates/components/key-exchange/src/circuit.rs @@ -1,8 +1,7 @@ //! This module provides the circuits used in the key exchange protocol. -use std::sync::Arc; - use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder}; +use std::sync::Arc; /// NIST P-256 prime big-endian. static P: [u8; 32] = [ diff --git a/crates/components/key-exchange/src/config.rs b/crates/components/key-exchange/src/config.rs index 5e32f55c7..4b25ea12a 100644 --- a/crates/components/key-exchange/src/config.rs +++ b/crates/components/key-exchange/src/config.rs @@ -23,7 +23,7 @@ impl KeyExchangeConfig { } /// Get the role of this instance. - pub fn role(&self) -> &Role { - &self.role + pub fn role(&self) -> Role { + self.role } } diff --git a/crates/components/key-exchange/src/error.rs b/crates/components/key-exchange/src/error.rs index cf9e7da4d..50a06c8ee 100644 --- a/crates/components/key-exchange/src/error.rs +++ b/crates/components/key-exchange/src/error.rs @@ -1,120 +1,112 @@ -use core::fmt; -use std::error::Error; +use std::{error::Error, fmt::Display}; -/// A key exchange error. +/// MPC-TLS protocol error. #[derive(Debug, thiserror::Error)] -pub struct KeyExchangeError { - kind: ErrorKind, - #[source] - source: Option>, +#[error(transparent)] +pub struct KeyExchangeError(#[from] ErrorRepr); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum ErrorRepr { + /// An unexpected state was encountered + State(Box), + /// Context error. + Ctx(Box), + /// IO related error + Io(Box), + /// Virtual machine error + Vm(Box), + /// Share conversion error + ShareConversion(Box), + /// Role error + Role(Box), + /// Key error + Key(Box), } -impl KeyExchangeError { - pub(crate) fn new(kind: ErrorKind, source: E) -> Self - where - E: Into>, - { - Self { - kind, - source: Some(source.into()), +impl Display for ErrorRepr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorRepr::State(error) => write!(f, "{error}"), + ErrorRepr::Ctx(error) => write!(f, "{error}"), + ErrorRepr::Io(error) => write!(f, "{error}"), + ErrorRepr::Vm(error) => write!(f, "{error}"), + ErrorRepr::ShareConversion(error) => write!(f, "{error}"), + ErrorRepr::Role(error) => write!(f, "{error}"), + ErrorRepr::Key(error) => write!(f, "{error}"), } } +} - #[cfg(test)] - pub(crate) fn kind(&self) -> &ErrorKind { - &self.kind - } - - pub(crate) fn state(msg: impl Into) -> Self { - Self { - kind: ErrorKind::State, - source: Some(msg.into().into()), - } +impl KeyExchangeError { + pub(crate) fn state(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::State(err.into())) } - pub(crate) fn role(msg: impl Into) -> Self { - Self { - kind: ErrorKind::Role, - source: Some(msg.into().into()), - } + pub(crate) fn ctx(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::Ctx(err.into())) } -} - -#[derive(Debug)] -pub(crate) enum ErrorKind { - Io, - Context, - Vm, - ShareConversion, - Key, - State, - Role, -} - -impl fmt::Display for KeyExchangeError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.kind { - ErrorKind::Io => write!(f, "io error")?, - ErrorKind::Context => write!(f, "context error")?, - ErrorKind::Vm => write!(f, "vm error")?, - ErrorKind::ShareConversion => write!(f, "share conversion error")?, - ErrorKind::Key => write!(f, "key error")?, - ErrorKind::State => write!(f, "state error")?, - ErrorKind::Role => write!(f, "role error")?, - } - if let Some(ref source) = self.source { - write!(f, " caused by: {}", source)?; - } - - Ok(()) + pub(crate) fn io(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::Io(err.into())) } -} -impl From for KeyExchangeError { - fn from(error: mpz_common::ContextError) -> Self { - Self::new(ErrorKind::Context, error) + pub(crate) fn vm(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::Vm(err.into())) } -} -impl From for KeyExchangeError { - fn from(error: mpz_garble::MemoryError) -> Self { - Self::new(ErrorKind::Vm, error) + pub(crate) fn share_conversion(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::ShareConversion(err.into())) } -} -impl From for KeyExchangeError { - fn from(error: mpz_garble::LoadError) -> Self { - Self::new(ErrorKind::Vm, error) + pub(crate) fn role(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::Role(err.into())) } -} -impl From for KeyExchangeError { - fn from(error: mpz_garble::ExecutionError) -> Self { - Self::new(ErrorKind::Vm, error) + pub(crate) fn key(err: E) -> KeyExchangeError + where + E: Into>, + { + Self(ErrorRepr::Key(err.into())) } -} -impl From for KeyExchangeError { - fn from(error: mpz_garble::DecodeError) -> Self { - Self::new(ErrorKind::Vm, error) + #[cfg(test)] + pub(crate) fn kind(&self) -> &ErrorRepr { + &self.0 } } -impl From for KeyExchangeError { - fn from(error: mpz_share_conversion::ShareConversionError) -> Self { - Self::new(ErrorKind::ShareConversion, error) +impl From for KeyExchangeError { + fn from(value: mpz_common::ContextError) -> Self { + Self::ctx(value) } } impl From for KeyExchangeError { - fn from(error: p256::elliptic_curve::Error) -> Self { - Self::new(ErrorKind::Key, error) + fn from(value: p256::elliptic_curve::Error) -> Self { + Self::key(value) } } impl From for KeyExchangeError { - fn from(error: std::io::Error) -> Self { - Self::new(ErrorKind::Io, error) + fn from(value: std::io::Error) -> Self { + Self::io(value) } } diff --git a/crates/components/key-exchange/src/exchange.rs b/crates/components/key-exchange/src/exchange.rs index 76cc53047..e98c073c7 100644 --- a/crates/components/key-exchange/src/exchange.rs +++ b/crates/components/key-exchange/src/exchange.rs @@ -1,448 +1,511 @@ //! This module implements the key exchange logic. use async_trait::async_trait; -use mpz_common::{scoped_futures::ScopedFutureExt, Allocate, Context, Preprocess}; -use mpz_garble::{value::ValueRef, Decode, Execute, Load, Memory}; +use mpz_common::{Context, Flush}; use mpz_fields::{p256::P256, Field}; -use mpz_share_conversion::{ShareConversionError, ShareConvert}; +use mpz_memory_core::{ + binary::{Binary, U8}, + Array, Memory, MemoryExt, View, ViewExt, +}; +use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert}; +use mpz_vm_core::{CallBuilder, Vm, VmExt}; use p256::{EncodedPoint, PublicKey, SecretKey}; -use serio::{stream::IoStreamExt, SinkExt}; +use serio::{sink::SinkExt, stream::IoStreamExt}; use std::fmt::Debug; -use tracing::{debug, instrument}; +use tracing::instrument; use crate::{ circuit::build_pms_circuit, config::{KeyExchangeConfig, Role}, - error::ErrorKind, point_addition::derive_x_coord_share, - KeyExchange, KeyExchangeError, Pms, + EqualityCheck, KeyExchange, KeyExchangeError, Pms, }; #[derive(Debug)] enum State { - Initialized, + Initialized { + /// The private key of the party behind this instance, either follower or leader. + private_key: SecretKey, + }, Setup { - share_a0: ValueRef, - share_b0: ValueRef, - share_a1: ValueRef, - share_b1: ValueRef, - pms_0: ValueRef, - pms_1: ValueRef, - eq: ValueRef, + private_key: SecretKey, + share_a0: Array, + share_b0: Array, + share_a1: Array, + share_b1: Array, + eq: Array, + }, + SetFollowerKey { + private_key: SecretKey, + /// The public key of the follower + follower_key: PublicKey, + share_a0: Array, + share_b0: Array, + share_a1: Array, + share_b1: Array, + eq: Array, }, - Preprocessed { - share_a0: ValueRef, - share_b0: ValueRef, - share_a1: ValueRef, - share_b1: ValueRef, - pms_0: ValueRef, - pms_1: ValueRef, - eq: ValueRef, + SetAllKeys { + private_key: SecretKey, + /// The public key of the server. + server_key: PublicKey, + share_a0: Array, + share_b0: Array, + share_a1: Array, + share_b1: Array, + eq: Array, + }, + ComputedECShares { + server_key: PublicKey, + share_a0: Array, + share_b0: Array, + share_a1: Array, + share_b1: Array, + eq: Array, + pms_0: P256, + pms_1: P256, }, Complete, Error, } -impl State { - fn is_preprocessed(&self) -> bool { - matches!(self, Self::Preprocessed { .. }) - } - - fn take(&mut self) -> Self { - std::mem::replace(self, Self::Error) - } -} - /// An MPC key exchange protocol. /// /// Can be either a leader or a follower depending on the `role` field in /// [`KeyExchangeConfig`]. #[derive(Debug)] -pub struct MpcKeyExchange { - ctx: Ctx, +pub struct MpcKeyExchange { /// Share conversion protocol 0. converter_0: C0, /// Share conversion protocol 1. converter_1: C1, - /// MPC executor. - executor: E, - /// The private key of the party behind this instance, either follower or - /// leader. - private_key: Option, - /// The public key of the server. - server_key: Option, /// The config used for the key exchange protocol. config: KeyExchangeConfig, /// The state of the protocol. state: State, } -impl MpcKeyExchange { +impl MpcKeyExchange { /// Creates a new [`MpcKeyExchange`]. /// /// # Arguments /// /// * `config` - Key exchange configuration. - /// * `ctx` - Thread context. /// * `converter_0` - Share conversion protocol instance 0. /// * `converter_1` - Share conversion protocol instance 1. - /// * `executor` - MPC executor. - pub fn new( - config: KeyExchangeConfig, - ctx: Ctx, - converter_0: C0, - converter_1: C1, - executor: E, - ) -> Self { + pub fn new(config: KeyExchangeConfig, converter_0: C0, converter_1: C1) -> Self { + let private_key = SecretKey::random(&mut rand::rngs::OsRng); + Self { - ctx, converter_0, converter_1, - executor, - private_key: None, - server_key: None, config, - state: State::Initialized, + state: State::Initialized { private_key }, } } -} -impl MpcKeyExchange -where - Ctx: Context, - E: Execute + Load + Memory + Decode + Send, - C0: ShareConvert + Send, - C1: ShareConvert + Send, -{ - async fn compute_pms_shares( - &mut self, - server_key: PublicKey, - private_key: SecretKey, - ) -> Result<(P256, P256), KeyExchangeError> { - compute_pms_shares( - &mut self.ctx, - *self.config.role(), + async fn compute_ec_shares(&mut self, ctx: &mut Ctx) -> Result<(), KeyExchangeError> + where + Ctx: Context, + C0: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, + C1: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, + { + let State::SetAllKeys { + private_key, + server_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + .. + } = std::mem::replace(&mut self.state, State::Error) + else { + return Err(KeyExchangeError::state( + "should be in SetAllKeys state to compute pms", + )); + }; + let (pms_0, pms_1) = compute_ec_shares( + ctx, + self.config.role(), &mut self.converter_0, &mut self.converter_1, - server_key, private_key, + server_key, ) - .await - } + .await?; - // Computes the PMS using both parties' shares, performing an equality check - // to ensure the shares are equal. - async fn compute_pms_with( - &mut self, - share_0: P256, - share_1: P256, - ) -> Result { - let State::Preprocessed { + self.state = State::ComputedECShares { + server_key, share_a0, share_b0, share_a1, share_b1, + eq, pms_0, pms_1, - eq, - } = self.state.take() - else { - return Err(KeyExchangeError::state("not in preprocessed state")); }; + Ok(()) + } +} - let share_0_bytes: [u8; 32] = share_0 - .to_be_bytes() - .try_into() - .expect("pms share is 32 bytes"); - let share_1_bytes: [u8; 32] = share_1 - .to_be_bytes() - .try_into() - .expect("pms share is 32 bytes"); +impl KeyExchange for MpcKeyExchange +where + V: Vm + Memory + View + Send, + C0: ShareConvert + Send, + C1: ShareConvert + Send, +{ + fn alloc(&mut self) -> Result<(), KeyExchangeError> { + // 2 A2M, 1 M2A. + >::alloc(&mut self.converter_0, 1) + .map_err(KeyExchangeError::share_conversion)?; + >::alloc(&mut self.converter_1, 1) + .map_err(KeyExchangeError::share_conversion)?; - match self.config.role() { - Role::Leader => { - self.executor.assign(&share_a0, share_0_bytes)?; - self.executor.assign(&share_a1, share_1_bytes)?; - } - Role::Follower => { - self.executor.assign(&share_b0, share_0_bytes)?; - self.executor.assign(&share_b1, share_1_bytes)?; - } - } + >::alloc(&mut self.converter_0, 2) + .map_err(KeyExchangeError::share_conversion)?; + >::alloc(&mut self.converter_1, 2) + .map_err(KeyExchangeError::share_conversion)?; - self.executor - .execute( - build_pms_circuit(), - &[share_a0, share_b0, share_a1, share_b1], - &[pms_0.clone(), pms_1, eq.clone()], - ) - .await?; - - let eq: [u8; 32] = self - .executor - .decode(&[eq]) - .await? - .pop() - .expect("output 0 is eq") - .try_into() - .expect("eq is 32 bytes"); + Ok(()) + } + + fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError> { + let Role::Leader = self.config.role() else { + return Err(KeyExchangeError::role("follower cannot set server key")); + }; - // Eq should be all zeros if pms_1 == pms_2. - if eq != [0u8; 32] { - return Err(KeyExchangeError::new( - ErrorKind::ShareConversion, - "PMS values not equal", + let State::SetFollowerKey { + private_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + .. + } = std::mem::replace(&mut self.state, State::Error) + else { + return Err(KeyExchangeError::state( + "leader must be in SetFollowerKey state to set the server key", )); - } + }; - // Both parties use pms_0 as the pre-master secret. - Ok(Pms::new(pms_0)) + self.state = State::SetAllKeys { + private_key, + server_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + }; + Ok(()) } -} -#[async_trait] -impl KeyExchange for MpcKeyExchange -where - Ctx: Context, - E: Execute + Load + Memory + Decode + Send, - C0: Allocate + Preprocess + ShareConvert + Send, - C1: Allocate + Preprocess + ShareConvert + Send, -{ fn server_key(&self) -> Option { - self.server_key + match self.state { + State::SetAllKeys { server_key, .. } => Some(server_key), + State::ComputedECShares { server_key, .. } => Some(server_key), + _ => None, + } } - async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError> { + #[instrument(level = "debug", skip_all, err)] + fn client_key(&self) -> Result { let Role::Leader = self.config.role() else { - return Err(KeyExchangeError::role("follower cannot set server key")); + return Err(KeyExchangeError::role("follower does not learn client key")); }; - // Send server public key to follower. - self.ctx.io_mut().send(server_key).await?; + let State::SetFollowerKey { + private_key, + follower_key, + .. + } = &self.state + else { + return Err(KeyExchangeError::state( + "leader should be in SetFollowerKey state for returning the client key", + )); + }; - self.server_key = Some(server_key); + let public_key = private_key.public_key(); - Ok(()) + // Combine public keys. + let client_public_key = PublicKey::from_affine( + (public_key.to_projective() + follower_key.to_projective()).to_affine(), + )?; + + Ok(client_public_key) } #[instrument(level = "debug", skip_all, err)] - async fn setup(&mut self) -> Result { - let State::Initialized = self.state.take() else { - return Err(KeyExchangeError::state("not in initialized state")); + fn setup(&mut self, vm: &mut V) -> Result { + let State::Initialized { private_key } = std::mem::replace(&mut self.state, State::Error) + else { + return Err(KeyExchangeError::state( + "should be in Initialized state to call setup", + )); }; - - // 2 A2M, 1 M2A. - self.converter_0.alloc(3); - self.converter_1.alloc(3); - let (share_a0, share_b0, share_a1, share_b1) = match self.config.role() { Role::Leader => { - let share_a0 = self - .executor - .new_private_input::<[u8; 32]>("pms/share_a0")?; - let share_b0 = self.executor.new_blind_input::<[u8; 32]>("pms/share_b0")?; - let share_a1 = self - .executor - .new_private_input::<[u8; 32]>("pms/share_a1")?; - let share_b1 = self.executor.new_blind_input::<[u8; 32]>("pms/share_b1")?; + let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_a0).map_err(KeyExchangeError::vm)?; + + let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_b0).map_err(KeyExchangeError::vm)?; + + let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_a1).map_err(KeyExchangeError::vm)?; + + let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_b1).map_err(KeyExchangeError::vm)?; (share_a0, share_b0, share_a1, share_b1) } Role::Follower => { - let share_a0 = self.executor.new_blind_input::<[u8; 32]>("pms/share_a0")?; - let share_b0 = self - .executor - .new_private_input::<[u8; 32]>("pms/share_b0")?; - let share_a1 = self.executor.new_blind_input::<[u8; 32]>("pms/share_a1")?; - let share_b1 = self - .executor - .new_private_input::<[u8; 32]>("pms/share_b1")?; + let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_a0).map_err(KeyExchangeError::vm)?; + + let share_b0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_b0).map_err(KeyExchangeError::vm)?; + + let share_a1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_blind(share_a1).map_err(KeyExchangeError::vm)?; + + let share_b1: Array = vm.alloc().map_err(KeyExchangeError::vm)?; + vm.mark_private(share_b1).map_err(KeyExchangeError::vm)?; (share_a0, share_b0, share_a1, share_b1) } }; - let pms_0 = self.executor.new_output::<[u8; 32]>("pms_0")?; - let pms_1 = self.executor.new_output::<[u8; 32]>("pms_1")?; - let eq = self.executor.new_output::<[u8; 32]>("eq")?; + let pms_circuit = build_pms_circuit(); + let pms_call = CallBuilder::new(pms_circuit) + .arg(share_a0) + .arg(share_b0) + .arg(share_a1) + .arg(share_b1) + .build() + .map_err(KeyExchangeError::vm)?; + + let (pms, _, eq): (Array, Array, Array) = + vm.call(pms_call).map_err(KeyExchangeError::vm)?; self.state = State::Setup { share_a0, share_b0, share_a1, share_b1, - pms_0: pms_0.clone(), - pms_1, eq, + private_key, }; - Ok(Pms::new(pms_0)) + Ok(Pms::new(pms)) } #[instrument(level = "debug", skip_all, err)] - async fn preprocess(&mut self) -> Result<(), KeyExchangeError> { - let State::Setup { + fn compute_pms(&mut self, vm: &mut V) -> Result { + let State::ComputedECShares { share_a0, share_b0, share_a1, share_b1, + eq, pms_0, pms_1, - eq, - } = self.state.take() + .. + } = std::mem::replace(&mut self.state, State::Error) else { - return Err(KeyExchangeError::state("not in setup state")); + return Err(KeyExchangeError::state( + "should be in ComputedECShares state to compute pms", + )); }; - // Preprocess share conversion and garbled circuits concurrently. - futures::try_join!( - async { - self.ctx - .try_join( - |ctx| self.converter_0.preprocess(ctx).scope_boxed(), - |ctx| self.converter_1.preprocess(ctx).scope_boxed(), - ) - .await??; - - Ok::<_, KeyExchangeError>(()) - }, - async { - self.executor - .load( - build_pms_circuit(), - &[ - share_a0.clone(), - share_b0.clone(), - share_a1.clone(), - share_b1.clone(), - ], - &[pms_0.clone(), pms_1.clone(), eq.clone()], - ) - .await?; - - Ok::<_, KeyExchangeError>(()) - } - )?; - - // Follower can forward their key share immediately. - if let Role::Follower = self.config.role() { - let private_key = self - .private_key - .get_or_insert_with(|| SecretKey::random(&mut rand::rngs::OsRng)); + let share_0_bytes: [u8; 32] = pms_0 + .to_be_bytes() + .try_into() + .expect("pms share is 32 bytes"); + let share_1_bytes: [u8; 32] = pms_1 + .to_be_bytes() + .try_into() + .expect("pms share is 32 bytes"); - self.ctx.io_mut().send(private_key.public_key()).await?; + match self.config.role() { + Role::Leader => { + vm.assign(share_a0, share_0_bytes) + .map_err(KeyExchangeError::vm)?; + vm.commit(share_a0).map_err(KeyExchangeError::vm)?; - debug!("sent public key share to leader"); - } + vm.assign(share_a1, share_1_bytes) + .map_err(KeyExchangeError::vm)?; + vm.commit(share_a1).map_err(KeyExchangeError::vm)?; - self.state = State::Preprocessed { - share_a0, - share_b0, - share_a1, - share_b1, - pms_0, - pms_1, - eq, - }; + vm.commit(share_b0).map_err(KeyExchangeError::vm)?; + vm.commit(share_b1).map_err(KeyExchangeError::vm)?; + } + Role::Follower => { + vm.assign(share_b0, share_0_bytes) + .map_err(KeyExchangeError::vm)?; + vm.commit(share_b0).map_err(KeyExchangeError::vm)?; - Ok(()) - } + vm.assign(share_b1, share_1_bytes) + .map_err(KeyExchangeError::vm)?; + vm.commit(share_b1).map_err(KeyExchangeError::vm)?; - #[instrument(level = "debug", skip_all, err)] - async fn client_key(&mut self) -> Result { - if let Role::Leader = self.config.role() { - let private_key = self - .private_key - .get_or_insert_with(|| SecretKey::random(&mut rand::rngs::OsRng)); - let public_key = private_key.public_key(); + vm.commit(share_a0).map_err(KeyExchangeError::vm)?; + vm.commit(share_a1).map_err(KeyExchangeError::vm)?; + } + } - // Receive public key share from follower. - let follower_public_key: PublicKey = self.ctx.io_mut().expect_next().await?; + let check = vm.decode(eq).map_err(KeyExchangeError::vm)?; + let check = EqualityCheck(check); - debug!("received public key share from follower"); + self.state = State::Complete; + Ok(check) + } +} - // Combine public keys. - let client_public_key = PublicKey::from_affine( - (public_key.to_projective() + follower_public_key.to_projective()).to_affine(), - )?; +#[async_trait] +impl Flush for MpcKeyExchange +where + Ctx: Context, + C0: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, + C1: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, +{ + type Error = KeyExchangeError; - Ok(client_public_key) + fn wants_flush(&self) -> bool { + if let Role::Leader = self.config.role() { + matches!(self.state, State::Setup { .. } | State::SetAllKeys { .. }) } else { - Err(KeyExchangeError::role("follower does not learn client key")) + matches!(self.state, State::Setup { .. }) } } - #[instrument(level = "debug", skip_all, err)] - async fn compute_pms(&mut self) -> Result { - if !self.state.is_preprocessed() { - return Err(KeyExchangeError::state("not in preprocessed state")); - } - - let server_key = match self.config.role() { - Role::Leader => self - .server_key - .ok_or_else(|| KeyExchangeError::state("server public key not set"))?, - Role::Follower => { - // Receive server public key from leader. - let server_key = self.ctx.io_mut().expect_next().await?; - - self.server_key = Some(server_key); - - server_key + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + if let Role::Leader = self.config.role() { + match &mut self.state { + State::Setup { + private_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + } => { + let follower_key = ctx + .io_mut() + .expect_next() + .await + .map_err(KeyExchangeError::io)?; + + self.state = State::SetFollowerKey { + private_key: private_key.clone(), + follower_key, + share_a0: *share_a0, + share_b0: *share_b0, + share_a1: *share_a1, + share_b1: *share_b1, + eq: *eq, + }; + } + State::SetAllKeys { server_key, .. } => { + ctx.io_mut() + .send(*server_key) + .await + .map_err(KeyExchangeError::io)?; + self.compute_ec_shares(ctx).await?; + } + _ => (), } - }; - - let private_key = self - .private_key - .take() - .ok_or(KeyExchangeError::state("private key not set"))?; - - let (pms_share_0, pms_share_1) = self.compute_pms_shares(server_key, private_key).await?; - let pms = self.compute_pms_with(pms_share_0, pms_share_1).await?; - - self.state = State::Complete; - - Ok(pms) + } else if let State::Setup { + private_key, + share_a0, + share_b0, + share_a1, + share_b1, + eq, + } = &mut self.state + { + let follower_key = private_key.public_key(); + ctx.io_mut() + .send(follower_key) + .await + .map_err(KeyExchangeError::io)?; + + let server_key: PublicKey = ctx + .io_mut() + .expect_next() + .await + .map_err(KeyExchangeError::io)?; + + self.state = State::SetAllKeys { + private_key: private_key.clone(), + server_key, + share_a0: *share_a0, + share_b0: *share_b0, + share_a1: *share_a1, + share_b1: *share_b1, + eq: *eq, + }; + self.compute_ec_shares(ctx).await?; + } + Ok(()) } } -async fn compute_pms_shares< - Ctx: Context, - C0: ShareConvert + Send, - C1: ShareConvert + Send, ->( +async fn compute_ec_shares( ctx: &mut Ctx, role: Role, converter_0: &mut C0, converter_1: &mut C1, - server_key: PublicKey, private_key: SecretKey, -) -> Result<(P256, P256), KeyExchangeError> { + server_key: PublicKey, +) -> Result<(P256, P256), KeyExchangeError> +where + Ctx: Context, + C0: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, + C1: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, +{ // Compute the leader's/follower's share of the pre-master secret. // - // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function - // without the [SharedSecret](p256::ecdh::SharedSecret) wrapper, because - // this makes it harder to get the result as an EC curve point. + // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function without the + // [SharedSecret](p256::ecdh::SharedSecret) wrapper, because this makes it harder to get the + // result as an EC curve point. let shared_secret = { let public_projective = server_key.to_projective(); (public_projective * private_key.to_nonzero_scalar().as_ref()).to_affine() }; let encoded_point = EncodedPoint::from(PublicKey::from_affine(shared_secret)?); - - let (pms_share_0, pms_share_1) = ctx - .try_join( - |ctx| { - async { derive_x_coord_share(role, ctx, converter_0, encoded_point).await } - .scope_boxed() - }, - |ctx| { - async { derive_x_coord_share(role, ctx, converter_1, encoded_point).await } - .scope_boxed() - }, - ) - .await??; + let pms_share_0 = derive_x_coord_share(ctx, role, converter_0, encoded_point).await?; + let pms_share_1 = derive_x_coord_share(ctx, role, converter_1, encoded_point).await?; + + // TODO: Fix lifetimes here + //let (pms_share_0, pms_share_1) = ctx + // .try_join( + // |ctx| { + // async { derive_x_coord_share(ctx, role, converter_0, encoded_point).await } + // .scope_boxed() + // }, + // |ctx| { + // async { derive_x_coord_share(ctx, role, converter_1, encoded_point).await } + // .scope_boxed() + // }, + // ) + // .await??; Ok((pms_share_0, pms_share_1)) } @@ -450,63 +513,42 @@ async fn compute_pms_shares< #[cfg(test)] mod tests { use super::*; - - use mpz_common::executor::{test_st_executor, STExecutor}; - use mpz_garble::protocol::deap::mock::{create_mock_deap_vm, MockFollower, MockLeader}; - use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter}; + use crate::error::ErrorRepr; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_garble::protocol::semihonest::{Evaluator, Generator}; + use mpz_memory_core::correlated::Delta; + use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender}; + use mpz_share_conversion::ideal::{ + ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender, + }; + use mpz_vm_core::Execute; use p256::{NonZeroScalar, PublicKey, SecretKey}; + use rand::rngs::StdRng; use rand_chacha::ChaCha12Rng; use rand_core::SeedableRng; - use serio::channel::MemoryDuplex; - - #[allow(clippy::type_complexity)] - fn create_pair() -> ( - MpcKeyExchange< - STExecutor, - IdealShareConverter, - IdealShareConverter, - MockLeader, - >, - MpcKeyExchange< - STExecutor, - IdealShareConverter, - IdealShareConverter, - MockFollower, - >, - ) { - let (leader_ctx, follower_ctx) = test_st_executor(8); - let (leader_converter_0, follower_converter_0) = ideal_share_converter(); - let (follower_converter_1, leader_converter_1) = ideal_share_converter(); - let (leader_vm, follower_vm) = create_mock_deap_vm(); - - let leader = MpcKeyExchange::new( - KeyExchangeConfig::builder() - .role(Role::Leader) - .build() - .unwrap(), - leader_ctx, - leader_converter_0, - leader_converter_1, - leader_vm, - ); - let follower = MpcKeyExchange::new( - KeyExchangeConfig::builder() - .role(Role::Follower) - .build() - .unwrap(), - follower_ctx, - follower_converter_0, - follower_converter_1, - follower_vm, - ); + impl MpcKeyExchange { + fn set_private_key(&mut self, key: SecretKey) { + let State::Initialized { private_key } = &mut self.state else { + panic!("Can only set private key in initialized state") + }; + *private_key = key; + } - (leader, follower) + fn set_pms_0(&mut self, pms: P256) { + let State::ComputedECShares { pms_0, .. } = &mut self.state else { + panic!("Can only set private key in initialized state") + }; + *pms_0 = pms; + } } #[tokio::test] async fn test_key_exchange() { let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let (mut gen, mut ev) = mock_vm(); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); @@ -514,28 +556,54 @@ mod tests { let (mut leader, mut follower) = create_pair(); - leader.private_key = Some(leader_private_key.clone()); - follower.private_key = Some(follower_private_key.clone()); + KeyExchange::>::alloc(&mut leader).unwrap(); + KeyExchange::>::alloc(&mut follower).unwrap(); - tokio::try_join!(leader.setup(), follower.setup()).unwrap(); - tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + leader.set_private_key(leader_private_key.clone()); + follower.set_private_key(follower_private_key.clone()); - let client_public_key = leader.client_key().await.unwrap(); - leader.set_server_key(server_public_key).await.unwrap(); + leader.setup(&mut gen).unwrap(); + follower.setup(&mut ev).unwrap(); - let expected_client_public_key = PublicKey::from_affine( - (leader_private_key.public_key().to_projective() - + follower_private_key.public_key().to_projective()) - .to_affine(), + tokio::try_join!( + async { + leader.flush(&mut ctx_a).await.unwrap(); + + let client_public_key = + KeyExchange::>::client_key(&leader).unwrap(); + + KeyExchange::>::set_server_key( + &mut leader, + server_public_key, + ) + .unwrap(); + + assert_eq!( + KeyExchange::>::server_key(&leader).unwrap(), + server_public_key + ); + + let expected_client_public_key = PublicKey::from_affine( + (leader_private_key.public_key().to_projective() + + follower_private_key.public_key().to_projective()) + .to_affine(), + ) + .unwrap(); + + assert_eq!(client_public_key, expected_client_public_key); + leader.flush(&mut ctx_a).await.unwrap(); + Ok(()) + }, + follower.flush(&mut ctx_b) ) .unwrap(); - - assert_eq!(client_public_key, expected_client_public_key); } #[tokio::test] async fn test_compute_pms() { let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let (mut gen, mut ev) = mock_vm(); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); @@ -544,27 +612,75 @@ mod tests { let (mut leader, mut follower) = create_pair(); - leader.private_key = Some(leader_private_key); - follower.private_key = Some(follower_private_key); + KeyExchange::>::alloc(&mut leader).unwrap(); + KeyExchange::>::alloc(&mut follower).unwrap(); + + leader.set_private_key(leader_private_key.clone()); + follower.set_private_key(follower_private_key.clone()); - tokio::try_join!(leader.setup(), follower.setup()).unwrap(); - tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + let leader_pms = leader.setup(&mut gen).unwrap().into_value(); + let leader_pms = gen.decode(leader_pms).unwrap(); - leader.set_server_key(server_public_key).await.unwrap(); + let follower_pms = follower.setup(&mut ev).unwrap().into_value(); + let follower_pms = ev.decode(follower_pms).unwrap(); - let (_leader_pms, _follower_pms) = - tokio::try_join!(leader.compute_pms(), follower.compute_pms()).unwrap(); + tokio::try_join!( + async { + leader.flush(&mut ctx_a).await.unwrap(); + let _client_public_key = + KeyExchange::>::client_key(&leader).unwrap(); + + KeyExchange::>::set_server_key( + &mut leader, + server_public_key, + ) + .unwrap(); + assert_eq!( + KeyExchange::>::server_key(&leader).unwrap(), + server_public_key + ); + leader.flush(&mut ctx_a).await.unwrap(); + Ok(()) + }, + follower.flush(&mut ctx_b) + ) + .unwrap(); + + let eq_check_leader = leader.compute_pms(&mut gen).unwrap(); + let eq_check_follower = follower.compute_pms(&mut ev).unwrap(); - assert_eq!(leader.server_key.unwrap(), server_public_key); - assert_eq!(follower.server_key.unwrap(), server_public_key); + tokio::try_join!( + async { + gen.flush(&mut ctx_a).await.unwrap(); + gen.execute(&mut ctx_a).await.unwrap(); + gen.flush(&mut ctx_a) + .await + .map_err(KeyExchangeError::vm) + .unwrap(); + eq_check_leader.check().await + }, + async { + ev.flush(&mut ctx_b).await.unwrap(); + ev.execute(&mut ctx_b).await.unwrap(); + ev.flush(&mut ctx_b) + .await + .map_err(KeyExchangeError::vm) + .unwrap(); + eq_check_follower.check().await + } + ) + .unwrap(); + + let (leader_pms, follower_pms) = tokio::try_join!(leader_pms, follower_pms).unwrap(); + assert_eq!(leader_pms, follower_pms); } #[tokio::test] - async fn test_compute_pms_shares() { + async fn test_compute_ec_shares() { let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); let (mut ctx_leader, mut ctx_follower) = test_st_executor(8); - let (mut leader_converter_0, mut follower_converter_0) = ideal_share_converter(); - let (mut follower_converter_1, mut leader_converter_1) = ideal_share_converter(); + let (mut leader_converter_0, mut follower_converter_0) = ideal_share_convert(Block::ZERO); + let (mut follower_converter_1, mut leader_converter_1) = ideal_share_convert(Block::ZERO); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); @@ -580,21 +696,21 @@ mod tests { let ((leader_share_0, leader_share_1), (follower_share_0, follower_share_1)) = tokio::try_join!( - compute_pms_shares( + compute_ec_shares( &mut ctx_leader, Role::Leader, &mut leader_converter_0, &mut leader_converter_1, - server_public_key, - leader_private_key + leader_private_key, + server_public_key ), - compute_pms_shares( + compute_ec_shares( &mut ctx_follower, Role::Follower, &mut follower_converter_0, &mut follower_converter_1, - server_public_key, - follower_private_key + follower_private_key, + server_public_key ) ) .unwrap(); @@ -618,6 +734,8 @@ mod tests { #[tokio::test] async fn test_compute_pms_fail() { let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); + let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let (mut gen, mut ev) = mock_vm(); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); @@ -626,31 +744,215 @@ mod tests { let (mut leader, mut follower) = create_pair(); - leader.private_key = Some(leader_private_key.clone()); - follower.private_key = Some(follower_private_key.clone()); + KeyExchange::>::alloc(&mut leader).unwrap(); + KeyExchange::>::alloc(&mut follower).unwrap(); - tokio::try_join!(leader.setup(), follower.setup()).unwrap(); - tokio::try_join!(leader.preprocess(), follower.preprocess()).unwrap(); + leader.set_private_key(leader_private_key.clone()); + follower.set_private_key(follower_private_key.clone()); - leader.set_server_key(server_public_key).await.unwrap(); + leader.setup(&mut gen).unwrap(); + follower.setup(&mut ev).unwrap(); - let ((mut share_a0, share_a1), (share_b0, share_b1)) = tokio::try_join!( - leader.compute_pms_shares(server_public_key, leader_private_key), - follower.compute_pms_shares(server_public_key, follower_private_key) + tokio::try_join!( + async { + leader.flush(&mut ctx_a).await.unwrap(); + let _client_public_key = + KeyExchange::>::client_key(&leader).unwrap(); + + KeyExchange::>::set_server_key( + &mut leader, + server_public_key, + ) + .unwrap(); + assert_eq!( + KeyExchange::>::server_key(&leader).unwrap(), + server_public_key + ); + leader.flush(&mut ctx_a).await.unwrap(); + Ok(()) + }, + follower.flush(&mut ctx_b) ) .unwrap(); - share_a0 = share_a0 + P256::one(); + // Now manipulate pms + leader.set_pms_0(P256::one()); + follower.set_pms_0(P256::one()); + + let eq_check_leader = leader.compute_pms(&mut gen).unwrap(); + let eq_check_follower = follower.compute_pms(&mut ev).unwrap(); let (leader_res, follower_res) = tokio::join!( - leader.compute_pms_with(share_a0, share_a1), - follower.compute_pms_with(share_b0, share_b1) + async { + gen.flush(&mut ctx_a).await.unwrap(); + gen.execute(&mut ctx_a).await.unwrap(); + gen.flush(&mut ctx_a) + .await + .map_err(KeyExchangeError::vm) + .unwrap(); + eq_check_leader.check().await + }, + async { + ev.flush(&mut ctx_b).await.unwrap(); + ev.execute(&mut ctx_b).await.unwrap(); + ev.flush(&mut ctx_b) + .await + .map_err(KeyExchangeError::vm) + .unwrap(); + eq_check_follower.check().await + } ); let leader_err = leader_res.unwrap_err(); let follower_err = follower_res.unwrap_err(); - assert!(matches!(leader_err.kind(), ErrorKind::ShareConversion)); - assert!(matches!(follower_err.kind(), ErrorKind::ShareConversion)); + assert!(matches!(leader_err.kind(), ErrorRepr::ShareConversion(_))); + assert!(matches!(follower_err.kind(), ErrorRepr::ShareConversion(_))); + } + + #[tokio::test] + async fn test_circuit() { + let (mut ctx_a, mut ctx_b) = test_st_executor(8); + let (gen, ev) = mock_vm(); + + let share_a0_bytes = [5_u8; 32]; + let share_a1_bytes = [2_u8; 32]; + + let share_b0_bytes = [3_u8; 32]; + let share_b1_bytes = [6_u8; 32]; + + let (res_gen, res_ev) = tokio::join!( + async move { + let mut vm = gen; + let share_a0: Array = vm.alloc().unwrap(); + vm.mark_private(share_a0).unwrap(); + + let share_b0: Array = vm.alloc().unwrap(); + vm.mark_blind(share_b0).unwrap(); + + let share_a1: Array = vm.alloc().unwrap(); + vm.mark_private(share_a1).unwrap(); + + let share_b1: Array = vm.alloc().unwrap(); + vm.mark_blind(share_b1).unwrap(); + + let pms_circuit = build_pms_circuit(); + let pms_call = CallBuilder::new(pms_circuit) + .arg(share_a0) + .arg(share_b0) + .arg(share_a1) + .arg(share_b1) + .build() + .unwrap(); + + let (_, _, eq): (Array, Array, Array) = + vm.call(pms_call).unwrap(); + + vm.assign(share_a0, share_a0_bytes).unwrap(); + vm.commit(share_a0).unwrap(); + + vm.assign(share_a1, share_a1_bytes).unwrap(); + vm.commit(share_a1).unwrap(); + + vm.commit(share_b0).unwrap(); + vm.commit(share_b1).unwrap(); + + let check = vm.decode(eq).unwrap(); + + vm.flush(&mut ctx_a).await.unwrap(); + vm.execute(&mut ctx_a).await.unwrap(); + vm.flush(&mut ctx_a).await.unwrap(); + check.await + }, + async { + let mut vm = ev; + let share_a0: Array = vm.alloc().unwrap(); + vm.mark_blind(share_a0).unwrap(); + + let share_b0: Array = vm.alloc().unwrap(); + vm.mark_private(share_b0).unwrap(); + + let share_a1: Array = vm.alloc().unwrap(); + vm.mark_blind(share_a1).unwrap(); + + let share_b1: Array = vm.alloc().unwrap(); + vm.mark_private(share_b1).unwrap(); + + let pms_circuit = build_pms_circuit(); + let pms_call = CallBuilder::new(pms_circuit) + .arg(share_a0) + .arg(share_b0) + .arg(share_a1) + .arg(share_b1) + .build() + .unwrap(); + + let (_, _, eq): (Array, Array, Array) = + vm.call(pms_call).unwrap(); + + vm.assign(share_b0, share_b0_bytes).unwrap(); + vm.commit(share_b0).unwrap(); + + vm.assign(share_b1, share_b1_bytes).unwrap(); + vm.commit(share_b1).unwrap(); + + vm.commit(share_a0).unwrap(); + vm.commit(share_a1).unwrap(); + + let check = vm.decode(eq).unwrap(); + + vm.flush(&mut ctx_b).await.unwrap(); + vm.execute(&mut ctx_b).await.unwrap(); + vm.flush(&mut ctx_b).await.unwrap(); + check.await + } + ); + + let res_gen = res_gen.unwrap(); + let res_ev = res_ev.unwrap(); + + assert_eq!(res_gen, res_ev); + assert_eq!(res_gen, [0_u8; 32]); + } + + #[allow(clippy::type_complexity)] + fn create_pair() -> ( + MpcKeyExchange, IdealShareConvertReceiver>, + MpcKeyExchange, IdealShareConvertSender>, + ) { + let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO); + let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO); + + let leader = MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(Role::Leader) + .build() + .unwrap(), + leader_converter_0, + leader_converter_1, + ); + + let follower = MpcKeyExchange::new( + KeyExchangeConfig::builder() + .role(Role::Follower) + .build() + .unwrap(), + follower_converter_0, + follower_converter_1, + ); + + (leader, follower) + } + + fn mock_vm() -> (Generator, Evaluator) { + let mut rng = StdRng::seed_from_u64(0); + let delta = Delta::random(&mut rng); + + let (cot_send, cot_recv) = ideal_cot(delta.into_inner()); + + let gen = Generator::new(cot_send, [0u8; 16], delta); + let ev = Evaluator::new(cot_recv); + + (gen, ev) } } diff --git a/crates/components/key-exchange/src/lib.rs b/crates/components/key-exchange/src/lib.rs index 1215ebfe9..8bf8092c1 100644 --- a/crates/components/key-exchange/src/lib.rs +++ b/crates/components/key-exchange/src/lib.rs @@ -28,50 +28,70 @@ pub use config::{ pub use error::KeyExchangeError; pub use exchange::MpcKeyExchange; -use async_trait::async_trait; -use mpz_garble::value::ValueRef; +use mpz_core::bitvec::BitVec; +use mpz_memory_core::{binary::U8, Array, DecodeFutureTyped}; use p256::PublicKey; /// Pre-master secret. -#[derive(Debug, Clone)] -pub struct Pms(ValueRef); +#[derive(Debug, Clone, Copy)] +pub struct Pms(Array); impl Pms { /// Creates a new PMS. - pub fn new(value: ValueRef) -> Self { - Self(value) + pub fn new(pms: Array) -> Self { + Self(pms) } /// Gets the value of the PMS. - pub fn into_value(self) -> ValueRef { + pub fn into_value(self) -> Array { self.0 } } +/// Checks that both parties behaved honestly. +#[must_use] +#[derive(Debug)] +pub struct EqualityCheck(DecodeFutureTyped); + +impl EqualityCheck { + /// Checks that the PMS computation succeeded and that both parties agree on the PMS value. + /// + /// This MUST be called to ensure that no party cheated. + pub async fn check(self) -> Result<(), KeyExchangeError> { + let eq = self.0.await.map_err(KeyExchangeError::vm)?; + + // Eq should be all zeros if pms_1 == pms_2. + if eq != [0u8; 32] { + return Err(KeyExchangeError::share_conversion("PMS values not equal")); + } + Ok(()) + } +} + /// A trait for the 3-party key exchange protocol. -#[async_trait] -pub trait KeyExchange { - /// Gets the server's public key. - fn server_key(&self) -> Option; +pub trait KeyExchange { + /// Allocate necessary computational resources. + fn alloc(&mut self) -> Result<(), KeyExchangeError>; /// Sets the server's public key. - async fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>; + fn set_server_key(&mut self, server_key: PublicKey) -> Result<(), KeyExchangeError>; + + /// Gets the server's public key. + fn server_key(&self) -> Option; /// Computes the client's public key. /// /// The client's public key in this context is the combined public key (EC /// point addition) of the leader's public key and the follower's public /// key. - async fn client_key(&mut self) -> Result; + fn client_key(&self) -> Result; /// Performs any necessary one-time setup, returning a reference to the PMS. - /// - /// The PMS will not be assigned until `compute_pms` is called. - async fn setup(&mut self) -> Result; - - /// Preprocesses the key exchange. - async fn preprocess(&mut self) -> Result<(), KeyExchangeError>; + fn setup(&mut self, vm: &mut V) -> Result; - /// Computes the PMS. - async fn compute_pms(&mut self) -> Result; + /// Computes the PMS, and returns an equality check. + /// + /// The equality check makes sure that both parties arrived at the same result. This MUST be + /// called to prevent malicious behavior! + fn compute_pms(&mut self, vm: &mut V) -> Result; } diff --git a/crates/components/key-exchange/src/mock.rs b/crates/components/key-exchange/src/mock.rs index 00d9fe997..68693ff18 100644 --- a/crates/components/key-exchange/src/mock.rs +++ b/crates/components/key-exchange/src/mock.rs @@ -2,24 +2,20 @@ //! function to create such a pair. use crate::{KeyExchangeConfig, MpcKeyExchange, Role}; - -use mpz_common::executor::{test_st_executor, STExecutor}; -use mpz_garble::{Decode, Execute, Memory}; -use mpz_share_conversion::ideal::{ideal_share_converter, IdealShareConverter}; -use serio::channel::MemoryDuplex; +use mpz_core::Block; +use mpz_fields::p256::P256; +use mpz_share_conversion::ideal::{ + ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender, +}; /// A mock key exchange instance. -pub type MockKeyExchange = - MpcKeyExchange, IdealShareConverter, IdealShareConverter, E>; +pub type MockKeyExchange = + MpcKeyExchange, IdealShareConvertReceiver>; /// Creates a mock pair of key exchange leader and follower. -pub fn create_mock_key_exchange_pair( - leader_executor: E, - follower_executor: E, -) -> (MockKeyExchange, MockKeyExchange) { - let (leader_ctx, follower_ctx) = test_st_executor(8); - let (leader_converter_0, follower_converter_0) = ideal_share_converter(); - let (leader_converter_1, follower_converter_1) = ideal_share_converter(); +pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) { + let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO); + let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO); let key_exchange_config_leader = KeyExchangeConfig::builder() .role(Role::Leader) @@ -33,18 +29,14 @@ pub fn create_mock_key_exchange_pair( let leader = MpcKeyExchange::new( key_exchange_config_leader, - leader_ctx, leader_converter_0, leader_converter_1, - leader_executor, ); let follower = MpcKeyExchange::new( key_exchange_config_follower, - follower_ctx, - follower_converter_0, follower_converter_1, - follower_executor, + follower_converter_0, ); (leader, follower) @@ -52,20 +44,29 @@ pub fn create_mock_key_exchange_pair( #[cfg(test)] mod tests { - use mpz_garble::protocol::deap::mock::create_mock_deap_vm; - - use crate::KeyExchange; + use mpz_common::executor::TestSTExecutor; + use mpz_garble::protocol::semihonest::{Evaluator, Generator}; + use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender}; use super::*; + use crate::KeyExchange; #[test] fn test_mock_is_ke() { - let (leader_vm, follower_vm) = create_mock_deap_vm(); - let (leader, follower) = create_mock_key_exchange_pair(leader_vm, follower_vm); + let (leader, follower) = create_mock_key_exchange_pair(); + + fn is_key_exchange, Ctx, V>(_: T) {} - fn is_key_exchange(_: T) {} + is_key_exchange::< + MpcKeyExchange, IdealShareConvertReceiver>, + TestSTExecutor, + Generator, + >(leader); - is_key_exchange(leader); - is_key_exchange(follower); + is_key_exchange::< + MpcKeyExchange, IdealShareConvertReceiver>, + TestSTExecutor, + Evaluator, + >(follower); } } diff --git a/crates/components/key-exchange/src/point_addition.rs b/crates/components/key-exchange/src/point_addition.rs index d95056d8a..b3670ffbb 100644 --- a/crates/components/key-exchange/src/point_addition.rs +++ b/crates/components/key-exchange/src/point_addition.rs @@ -1,27 +1,28 @@ -//! This module implements a secure two-party computation protocol for adding -//! two private EC points and secret-sharing the resulting x coordinate (the -//! shares are field elements of the field underlying the elliptic curve). -//! This protocol has semi-honest security. +//! This module implements a secure two-party computation protocol for adding two private EC points +//! and secret-sharing the resulting x coordinate (the shares are field elements of the field +//! underlying the elliptic curve). This protocol has semi-honest security. //! -//! The protocol is described in +//! The protocol is described in +//! -use mpz_common::Context; +use crate::{config::Role, KeyExchangeError}; +use mpz_common::{Context, Flush}; use mpz_fields::{p256::P256, Field}; -use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive}; +use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert}; use p256::EncodedPoint; -use crate::{config::Role, error::ErrorKind, KeyExchangeError}; - /// Derives the x-coordinate share of an elliptic curve point. pub(crate) async fn derive_x_coord_share( - role: Role, ctx: &mut Ctx, + role: Role, converter: &mut C, share: EncodedPoint, ) -> Result where Ctx: Context, - C: AdditiveToMultiplicative + MultiplicativeToAdditive, + C: ShareConvert + Flush + Send, + >::Future: Send, + >::Future: Send, { let [x, y] = decompose_point(share)?; @@ -31,16 +32,40 @@ where Role::Follower => vec![-y, -x], }; - let [a, b] = converter - .to_multiplicative(ctx, inputs) - .await? + let a2m = converter + .queue_to_multiplicative(&inputs) + .map_err(KeyExchangeError::share_conversion)?; + + converter + .flush(ctx) + .await + .map_err(KeyExchangeError::share_conversion)?; + + let [a, b] = a2m + .await + .map_err(KeyExchangeError::share_conversion)? + .shares .try_into() .expect("output is same length as input"); - let c = a * b.inverse(); + let c = a * b + .inverse() + .expect("field element should not be zero when inverting"); let c = c * c; - let d = converter.to_additive(ctx, vec![c]).await?[0]; + let m2a = converter + .queue_to_additive(&[c]) + .map_err(KeyExchangeError::share_conversion)?; + + converter + .flush(ctx) + .await + .map_err(KeyExchangeError::share_conversion)?; + + let d = m2a + .await + .map_err(KeyExchangeError::share_conversion)? + .shares[0]; let x_r = d + -x; @@ -50,13 +75,11 @@ where /// Decomposes the x and y coordinates of a SEC1 encoded point. fn decompose_point(point: EncodedPoint) -> Result<[P256; 2], KeyExchangeError> { // Coordinates are stored as big-endian bytes. - let mut x: [u8; 32] = (*point.x().ok_or(KeyExchangeError::new( - ErrorKind::Key, - "key share is an identity point", - ))?) + let mut x: [u8; 32] = (*point + .x() + .ok_or(KeyExchangeError::key("key share is an identity point"))?) .into(); - let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::new( - ErrorKind::Key, + let mut y: [u8; 32] = (*point.y().ok_or(KeyExchangeError::key( "key share is an identity point or compressed", ))?) .into(); @@ -76,8 +99,9 @@ mod tests { use super::*; use mpz_common::executor::test_st_executor; + use mpz_core::Block; use mpz_fields::{p256::P256, Field}; - use mpz_share_conversion::ideal::ideal_share_converter; + use mpz_share_conversion::ideal::ideal_share_convert; use p256::{ elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}, EncodedPoint, NonZeroScalar, ProjectivePoint, PublicKey, @@ -98,11 +122,11 @@ mod tests { let p = add_curve_points(&p1, &p2); - let (mut c_a, mut c_b) = ideal_share_converter(); + let (mut c_a, mut c_b) = ideal_share_convert(Block::ZERO); let (a, b) = tokio::try_join!( - derive_x_coord_share(Role::Leader, &mut ctx_a, &mut c_a, p1), - derive_x_coord_share(Role::Follower, &mut ctx_b, &mut c_b, p2) + derive_x_coord_share(&mut ctx_a, Role::Leader, &mut c_a, p1), + derive_x_coord_share(&mut ctx_b, Role::Follower, &mut c_b, p2) ) .unwrap(); From d5eadc22e97bf1eec068defb3efef6cee1581239 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 16 Jan 2025 10:26:23 +0100 Subject: [PATCH 2/5] fix: fix feature flags --- crates/components/key-exchange/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/components/key-exchange/Cargo.toml b/crates/components/key-exchange/Cargo.toml index 21cd9679a..a8aafecc6 100644 --- a/crates/components/key-exchange/Cargo.toml +++ b/crates/components/key-exchange/Cargo.toml @@ -13,14 +13,14 @@ name = "key_exchange" [features] default = ["mock"] -mock = [] +mock = ["mpz-share-conversion/test-utils", "mpz-common/ideal"] [dependencies] mpz-vm-core = { workspace = true } mpz-memory-core = { workspace = true } mpz-common = { workspace = true } mpz-fields = { workspace = true } -mpz-share-conversion = { workspace = true, features = ["test-utils"] } +mpz-share-conversion = { workspace = true } mpz-circuits = { workspace = true } mpz-core = { workspace = true } From ce7d2c4939167159336203c2b4fbce2b95371f63 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 21 Jan 2025 15:59:02 -0800 Subject: [PATCH 3/5] simplify --- crates/components/key-exchange/Cargo.toml | 5 +- crates/components/key-exchange/src/error.rs | 13 +- .../components/key-exchange/src/exchange.rs | 523 ++++++++---------- crates/components/key-exchange/src/lib.rs | 41 +- crates/components/key-exchange/src/mock.rs | 2 +- 5 files changed, 239 insertions(+), 345 deletions(-) diff --git a/crates/components/key-exchange/Cargo.toml b/crates/components/key-exchange/Cargo.toml index a8aafecc6..7f138fc22 100644 --- a/crates/components/key-exchange/Cargo.toml +++ b/crates/components/key-exchange/Cargo.toml @@ -16,7 +16,7 @@ default = ["mock"] mock = ["mpz-share-conversion/test-utils", "mpz-common/ideal"] [dependencies] -mpz-vm-core = { workspace = true } +mpz-vm-core = { workspace = true } mpz-memory-core = { workspace = true } mpz-common = { workspace = true } mpz-fields = { workspace = true } @@ -31,9 +31,10 @@ serio = { workspace = true } derive_builder = { workspace = true } tracing = { workspace = true } rand = { workspace = true } +tokio = { workspace = true, features = ["sync"] } [dev-dependencies] -mpz-ot = { workspace = true , features = ["ideal"] } +mpz-ot = { workspace = true, features = ["ideal"] } mpz-garble = { workspace = true } rand_chacha = { workspace = true } diff --git a/crates/components/key-exchange/src/error.rs b/crates/components/key-exchange/src/error.rs index 50a06c8ee..a91a213de 100644 --- a/crates/components/key-exchange/src/error.rs +++ b/crates/components/key-exchange/src/error.rs @@ -12,7 +12,7 @@ pub(crate) enum ErrorRepr { /// Context error. Ctx(Box), /// IO related error - Io(Box), + Io(std::io::Error), /// Virtual machine error Vm(Box), /// Share conversion error @@ -52,13 +52,6 @@ impl KeyExchangeError { Self(ErrorRepr::Ctx(err.into())) } - pub(crate) fn io(err: E) -> KeyExchangeError - where - E: Into>, - { - Self(ErrorRepr::Io(err.into())) - } - pub(crate) fn vm(err: E) -> KeyExchangeError where E: Into>, @@ -106,7 +99,7 @@ impl From for KeyExchangeError { } impl From for KeyExchangeError { - fn from(value: std::io::Error) -> Self { - Self::io(value) + fn from(err: std::io::Error) -> Self { + Self(ErrorRepr::Io(err)) } } diff --git a/crates/components/key-exchange/src/exchange.rs b/crates/components/key-exchange/src/exchange.rs index e98c073c7..706f03423 100644 --- a/crates/components/key-exchange/src/exchange.rs +++ b/crates/components/key-exchange/src/exchange.rs @@ -1,44 +1,41 @@ //! This module implements the key exchange logic. +use std::{fmt::Debug, sync::Arc}; + use async_trait::async_trait; -use mpz_common::{Context, Flush}; +use p256::{EncodedPoint, PublicKey, SecretKey}; +use serio::{sink::SinkExt, stream::IoStreamExt}; +use tokio::sync::Mutex; +use tracing::instrument; +use mpz_common::{scoped_futures::ScopedFutureExt, Context, Flush}; +use mpz_core::bitvec::BitVec; use mpz_fields::{p256::P256, Field}; use mpz_memory_core::{ binary::{Binary, U8}, - Array, Memory, MemoryExt, View, ViewExt, + Array, DecodeFutureTyped, Memory, MemoryExt, View, ViewExt, }; use mpz_share_conversion::{AdditiveToMultiplicative, MultiplicativeToAdditive, ShareConvert}; use mpz_vm_core::{CallBuilder, Vm, VmExt}; -use p256::{EncodedPoint, PublicKey, SecretKey}; -use serio::{sink::SinkExt, stream::IoStreamExt}; -use std::fmt::Debug; -use tracing::instrument; use crate::{ circuit::build_pms_circuit, config::{KeyExchangeConfig, Role}, point_addition::derive_x_coord_share, - EqualityCheck, KeyExchange, KeyExchangeError, Pms, + KeyExchange, KeyExchangeError, Pms, }; #[derive(Debug)] enum State { - Initialized { - /// The private key of the party behind this instance, either follower or leader. - private_key: SecretKey, - }, + Initialized, Setup { - private_key: SecretKey, share_a0: Array, share_b0: Array, share_a1: Array, share_b1: Array, eq: Array, }, - SetFollowerKey { - private_key: SecretKey, - /// The public key of the follower + FollowerKey { follower_key: PublicKey, share_a0: Array, share_b0: Array, @@ -46,18 +43,7 @@ enum State { share_b1: Array, eq: Array, }, - SetAllKeys { - private_key: SecretKey, - /// The public key of the server. - server_key: PublicKey, - share_a0: Array, - share_b0: Array, - share_a1: Array, - share_b1: Array, - eq: Array, - }, ComputedECShares { - server_key: PublicKey, share_a0: Array, share_b0: Array, share_a1: Array, @@ -66,10 +52,19 @@ enum State { pms_0: P256, pms_1: P256, }, + EqualityCheck { + eq: DecodeFutureTyped, + }, Complete, Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Error) + } +} + /// An MPC key exchange protocol. /// /// Can be either a leader or a follower depending on the `role` field in @@ -77,13 +72,17 @@ enum State { #[derive(Debug)] pub struct MpcKeyExchange { /// Share conversion protocol 0. - converter_0: C0, + converter_0: Arc>, /// Share conversion protocol 1. - converter_1: C1, + converter_1: Arc>, /// The config used for the key exchange protocol. config: KeyExchangeConfig, /// The state of the protocol. state: State, + /// This party's private key. + private_key: SecretKey, + /// Server's public key. + server_key: Option, } impl MpcKeyExchange { @@ -98,78 +97,34 @@ impl MpcKeyExchange { let private_key = SecretKey::random(&mut rand::rngs::OsRng); Self { - converter_0, - converter_1, + converter_0: Arc::new(Mutex::new(converter_0)), + converter_1: Arc::new(Mutex::new(converter_1)), config, - state: State::Initialized { private_key }, - } - } - - async fn compute_ec_shares(&mut self, ctx: &mut Ctx) -> Result<(), KeyExchangeError> - where - Ctx: Context, - C0: ShareConvert + Flush + Send, - >::Future: Send, - >::Future: Send, - C1: ShareConvert + Flush + Send, - >::Future: Send, - >::Future: Send, - { - let State::SetAllKeys { - private_key, - server_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - .. - } = std::mem::replace(&mut self.state, State::Error) - else { - return Err(KeyExchangeError::state( - "should be in SetAllKeys state to compute pms", - )); - }; - let (pms_0, pms_1) = compute_ec_shares( - ctx, - self.config.role(), - &mut self.converter_0, - &mut self.converter_1, + state: State::Initialized, private_key, - server_key, - ) - .await?; - - self.state = State::ComputedECShares { - server_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - pms_0, - pms_1, - }; - Ok(()) + server_key: None, + } } } -impl KeyExchange for MpcKeyExchange +impl KeyExchange for MpcKeyExchange where - V: Vm + Memory + View + Send, C0: ShareConvert + Send, C1: ShareConvert + Send, { fn alloc(&mut self) -> Result<(), KeyExchangeError> { + let mut converter_0 = self.converter_0.try_lock().unwrap(); + let mut converter_1 = self.converter_1.try_lock().unwrap(); + // 2 A2M, 1 M2A. - >::alloc(&mut self.converter_0, 1) + MultiplicativeToAdditive::alloc(&mut *converter_0, 1) .map_err(KeyExchangeError::share_conversion)?; - >::alloc(&mut self.converter_1, 1) + MultiplicativeToAdditive::alloc(&mut *converter_1, 1) .map_err(KeyExchangeError::share_conversion)?; - >::alloc(&mut self.converter_0, 2) + AdditiveToMultiplicative::alloc(&mut *converter_0, 2) .map_err(KeyExchangeError::share_conversion)?; - >::alloc(&mut self.converter_1, 2) + AdditiveToMultiplicative::alloc(&mut *converter_1, 2) .map_err(KeyExchangeError::share_conversion)?; Ok(()) @@ -180,39 +135,13 @@ where return Err(KeyExchangeError::role("follower cannot set server key")); }; - let State::SetFollowerKey { - private_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - .. - } = std::mem::replace(&mut self.state, State::Error) - else { - return Err(KeyExchangeError::state( - "leader must be in SetFollowerKey state to set the server key", - )); - }; + self.server_key = Some(server_key); - self.state = State::SetAllKeys { - private_key, - server_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - }; Ok(()) } fn server_key(&self) -> Option { - match self.state { - State::SetAllKeys { server_key, .. } => Some(server_key), - State::ComputedECShares { server_key, .. } => Some(server_key), - _ => None, - } + self.server_key } #[instrument(level = "debug", skip_all, err)] @@ -221,18 +150,13 @@ where return Err(KeyExchangeError::role("follower does not learn client key")); }; - let State::SetFollowerKey { - private_key, - follower_key, - .. - } = &self.state - else { + let State::FollowerKey { follower_key, .. } = &self.state else { return Err(KeyExchangeError::state( - "leader should be in SetFollowerKey state for returning the client key", + "leader should be in FollowerKey state for returning the client key", )); }; - let public_key = private_key.public_key(); + let public_key = self.private_key.public_key(); // Combine public keys. let client_public_key = PublicKey::from_affine( @@ -243,13 +167,16 @@ where } #[instrument(level = "debug", skip_all, err)] - fn setup(&mut self, vm: &mut V) -> Result { - let State::Initialized { private_key } = std::mem::replace(&mut self.state, State::Error) - else { + fn setup(&mut self, vm: &mut V) -> Result + where + V: Vm + Memory + View, + { + let State::Initialized = self.state.take() else { return Err(KeyExchangeError::state( "should be in Initialized state to call setup", )); }; + let (share_a0, share_b0, share_a1, share_b1) = match self.config.role() { Role::Leader => { let share_a0: Array = vm.alloc().map_err(KeyExchangeError::vm)?; @@ -301,14 +228,16 @@ where share_a1, share_b1, eq, - private_key, }; Ok(Pms::new(pms)) } #[instrument(level = "debug", skip_all, err)] - fn compute_pms(&mut self, vm: &mut V) -> Result { + fn compute_pms(&mut self, vm: &mut V) -> Result<(), KeyExchangeError> + where + V: Vm + Memory + View, + { let State::ComputedECShares { share_a0, share_b0, @@ -361,11 +290,11 @@ where } } - let check = vm.decode(eq).map_err(KeyExchangeError::vm)?; - let check = EqualityCheck(check); + let eq = vm.decode(eq).map_err(KeyExchangeError::vm)?; - self.state = State::Complete; - Ok(check) + self.state = State::EqualityCheck { eq }; + + Ok(()) } } @@ -373,91 +302,106 @@ where impl Flush for MpcKeyExchange where Ctx: Context, - C0: ShareConvert + Flush + Send, + C0: ShareConvert + Flush + Send + 'static, >::Future: Send, >::Future: Send, - C1: ShareConvert + Flush + Send, + C1: ShareConvert + Flush + Send + 'static, >::Future: Send, >::Future: Send, { type Error = KeyExchangeError; fn wants_flush(&self) -> bool { - if let Role::Leader = self.config.role() { - matches!(self.state, State::Setup { .. } | State::SetAllKeys { .. }) - } else { - matches!(self.state, State::Setup { .. }) - } + !matches!( + self.state, + State::Initialized | State::Complete | State::Error + ) } async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { - if let Role::Leader = self.config.role() { - match &mut self.state { - State::Setup { - private_key, + if !self.wants_flush() { + return Ok(()); + } + + self.state = match self.state.take() { + State::Setup { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + } => { + let follower_key = match self.config.role() { + Role::Leader => ctx.io_mut().expect_next().await?, + Role::Follower => { + let follower_key = self.private_key.public_key(); + ctx.io_mut().send(follower_key).await?; + follower_key + } + }; + + State::FollowerKey { + follower_key, share_a0, share_b0, share_a1, share_b1, eq, - } => { - let follower_key = ctx - .io_mut() - .expect_next() - .await - .map_err(KeyExchangeError::io)?; - - self.state = State::SetFollowerKey { - private_key: private_key.clone(), - follower_key, - share_a0: *share_a0, - share_b0: *share_b0, - share_a1: *share_a1, - share_b1: *share_b1, - eq: *eq, - }; } - State::SetAllKeys { server_key, .. } => { - ctx.io_mut() - .send(*server_key) - .await - .map_err(KeyExchangeError::io)?; - self.compute_ec_shares(ctx).await?; + } + State::FollowerKey { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + .. + } => { + let server_key = match self.config.role() { + Role::Leader => { + let server_key = self + .server_key + .ok_or_else(|| KeyExchangeError::role("server key is not set"))?; + + ctx.io_mut().send(server_key).await?; + + server_key + } + Role::Follower => ctx.io_mut().expect_next().await?, + }; + + let (pms_0, pms_1) = compute_ec_shares( + ctx, + self.config.role(), + self.converter_0.clone(), + self.converter_1.clone(), + self.private_key.clone(), + server_key, + ) + .await?; + + State::ComputedECShares { + share_a0, + share_b0, + share_a1, + share_b1, + eq, + pms_0, + pms_1, } - _ => (), } - } else if let State::Setup { - private_key, - share_a0, - share_b0, - share_a1, - share_b1, - eq, - } = &mut self.state - { - let follower_key = private_key.public_key(); - ctx.io_mut() - .send(follower_key) - .await - .map_err(KeyExchangeError::io)?; - - let server_key: PublicKey = ctx - .io_mut() - .expect_next() - .await - .map_err(KeyExchangeError::io)?; - - self.state = State::SetAllKeys { - private_key: private_key.clone(), - server_key, - share_a0: *share_a0, - share_b0: *share_b0, - share_a1: *share_a1, - share_b1: *share_b1, - eq: *eq, - }; - self.compute_ec_shares(ctx).await?; - } + State::EqualityCheck { eq } => { + let eq = eq.await.map_err(KeyExchangeError::vm)?; + + if eq != [0u8; 32] { + return Err(KeyExchangeError::share_conversion("PMS values not equal")); + } + + State::Complete + } + state => state, + }; + Ok(()) } } @@ -465,47 +409,46 @@ where async fn compute_ec_shares( ctx: &mut Ctx, role: Role, - converter_0: &mut C0, - converter_1: &mut C1, + converter_0: Arc>, + converter_1: Arc>, private_key: SecretKey, server_key: PublicKey, ) -> Result<(P256, P256), KeyExchangeError> where Ctx: Context, - C0: ShareConvert + Flush + Send, + C0: ShareConvert + Flush + Send + 'static, >::Future: Send, >::Future: Send, - C1: ShareConvert + Flush + Send, + C1: ShareConvert + Flush + Send + 'static, >::Future: Send, >::Future: Send, { // Compute the leader's/follower's share of the pre-master secret. // - // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function without the - // [SharedSecret](p256::ecdh::SharedSecret) wrapper, because this makes it harder to get the - // result as an EC curve point. + // We need to mimic the [diffie-hellman](p256::ecdh::diffie_hellman) function + // without the [SharedSecret](p256::ecdh::SharedSecret) wrapper, because + // this makes it harder to get the result as an EC curve point. let shared_secret = { let public_projective = server_key.to_projective(); (public_projective * private_key.to_nonzero_scalar().as_ref()).to_affine() }; let encoded_point = EncodedPoint::from(PublicKey::from_affine(shared_secret)?); - let pms_share_0 = derive_x_coord_share(ctx, role, converter_0, encoded_point).await?; - let pms_share_1 = derive_x_coord_share(ctx, role, converter_1, encoded_point).await?; - - // TODO: Fix lifetimes here - //let (pms_share_0, pms_share_1) = ctx - // .try_join( - // |ctx| { - // async { derive_x_coord_share(ctx, role, converter_0, encoded_point).await } - // .scope_boxed() - // }, - // |ctx| { - // async { derive_x_coord_share(ctx, role, converter_1, encoded_point).await } - // .scope_boxed() - // }, - // ) - // .await??; + + let mut converter_0 = converter_0.try_lock_owned().unwrap(); + let mut converter_1 = converter_1.try_lock_owned().unwrap(); + let (pms_share_0, pms_share_1) = ctx + .try_join( + move |ctx| { + async move { derive_x_coord_share(ctx, role, &mut *converter_0, encoded_point).await } + .scope_boxed() + }, + move |ctx| { + async move { derive_x_coord_share(ctx, role, &mut *converter_1, encoded_point).await } + .scope_boxed() + }, + ) + .await??; Ok((pms_share_0, pms_share_1)) } @@ -529,13 +472,6 @@ mod tests { use rand_core::SeedableRng; impl MpcKeyExchange { - fn set_private_key(&mut self, key: SecretKey) { - let State::Initialized { private_key } = &mut self.state else { - panic!("Can only set private key in initialized state") - }; - *private_key = key; - } - fn set_pms_0(&mut self, pms: P256) { let State::ComputedECShares { pms_0, .. } = &mut self.state else { panic!("Can only set private key in initialized state") @@ -556,32 +492,24 @@ mod tests { let (mut leader, mut follower) = create_pair(); - KeyExchange::>::alloc(&mut leader).unwrap(); - KeyExchange::>::alloc(&mut follower).unwrap(); + leader.alloc().unwrap(); + follower.alloc().unwrap(); - leader.set_private_key(leader_private_key.clone()); - follower.set_private_key(follower_private_key.clone()); + leader.private_key = leader_private_key.clone(); + follower.private_key = follower_private_key.clone(); leader.setup(&mut gen).unwrap(); follower.setup(&mut ev).unwrap(); - tokio::try_join!( + tokio::join!( async { leader.flush(&mut ctx_a).await.unwrap(); - let client_public_key = - KeyExchange::>::client_key(&leader).unwrap(); + let client_public_key = leader.client_key().unwrap(); - KeyExchange::>::set_server_key( - &mut leader, - server_public_key, - ) - .unwrap(); + leader.set_server_key(server_public_key).unwrap(); - assert_eq!( - KeyExchange::>::server_key(&leader).unwrap(), - server_public_key - ); + assert_eq!(leader.server_key().unwrap(), server_public_key); let expected_client_public_key = PublicKey::from_affine( (leader_private_key.public_key().to_projective() @@ -592,11 +520,12 @@ mod tests { assert_eq!(client_public_key, expected_client_public_key); leader.flush(&mut ctx_a).await.unwrap(); - Ok(()) }, - follower.flush(&mut ctx_b) - ) - .unwrap(); + async { + follower.flush(&mut ctx_b).await.unwrap(); + follower.flush(&mut ctx_b).await.unwrap(); + } + ); } #[tokio::test] @@ -612,11 +541,11 @@ mod tests { let (mut leader, mut follower) = create_pair(); - KeyExchange::>::alloc(&mut leader).unwrap(); - KeyExchange::>::alloc(&mut follower).unwrap(); + leader.alloc().unwrap(); + follower.alloc().unwrap(); - leader.set_private_key(leader_private_key.clone()); - follower.set_private_key(follower_private_key.clone()); + leader.private_key = leader_private_key.clone(); + follower.private_key = follower_private_key.clone(); let leader_pms = leader.setup(&mut gen).unwrap().into_value(); let leader_pms = gen.decode(leader_pms).unwrap(); @@ -624,32 +553,23 @@ mod tests { let follower_pms = follower.setup(&mut ev).unwrap().into_value(); let follower_pms = ev.decode(follower_pms).unwrap(); - tokio::try_join!( + tokio::join!( async { leader.flush(&mut ctx_a).await.unwrap(); - let _client_public_key = - KeyExchange::>::client_key(&leader).unwrap(); - - KeyExchange::>::set_server_key( - &mut leader, - server_public_key, - ) - .unwrap(); - assert_eq!( - KeyExchange::>::server_key(&leader).unwrap(), - server_public_key - ); + let _client_public_key = leader.client_key().unwrap(); + leader.set_server_key(server_public_key).unwrap(); leader.flush(&mut ctx_a).await.unwrap(); - Ok(()) }, - follower.flush(&mut ctx_b) - ) - .unwrap(); + async { + follower.flush(&mut ctx_b).await.unwrap(); + follower.flush(&mut ctx_b).await.unwrap(); + } + ); - let eq_check_leader = leader.compute_pms(&mut gen).unwrap(); - let eq_check_follower = follower.compute_pms(&mut ev).unwrap(); + leader.compute_pms(&mut gen).unwrap(); + follower.compute_pms(&mut ev).unwrap(); - tokio::try_join!( + tokio::join!( async { gen.flush(&mut ctx_a).await.unwrap(); gen.execute(&mut ctx_a).await.unwrap(); @@ -657,7 +577,6 @@ mod tests { .await .map_err(KeyExchangeError::vm) .unwrap(); - eq_check_leader.check().await }, async { ev.flush(&mut ctx_b).await.unwrap(); @@ -666,10 +585,8 @@ mod tests { .await .map_err(KeyExchangeError::vm) .unwrap(); - eq_check_follower.check().await } - ) - .unwrap(); + ); let (leader_pms, follower_pms) = tokio::try_join!(leader_pms, follower_pms).unwrap(); assert_eq!(leader_pms, follower_pms); @@ -679,8 +596,13 @@ mod tests { async fn test_compute_ec_shares() { let mut rng = ChaCha12Rng::from_seed([0_u8; 32]); let (mut ctx_leader, mut ctx_follower) = test_st_executor(8); - let (mut leader_converter_0, mut follower_converter_0) = ideal_share_convert(Block::ZERO); - let (mut follower_converter_1, mut leader_converter_1) = ideal_share_convert(Block::ZERO); + let (leader_converter_0, follower_converter_0) = ideal_share_convert(Block::ZERO); + let (follower_converter_1, leader_converter_1) = ideal_share_convert(Block::ZERO); + + let leader_converter_0 = Arc::new(Mutex::new(leader_converter_0)); + let leader_converter_1 = Arc::new(Mutex::new(leader_converter_1)); + let follower_converter_0 = Arc::new(Mutex::new(follower_converter_0)); + let follower_converter_1 = Arc::new(Mutex::new(follower_converter_1)); let leader_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng); @@ -699,16 +621,16 @@ mod tests { compute_ec_shares( &mut ctx_leader, Role::Leader, - &mut leader_converter_0, - &mut leader_converter_1, + leader_converter_0, + leader_converter_1, leader_private_key, server_public_key ), compute_ec_shares( &mut ctx_follower, Role::Follower, - &mut follower_converter_0, - &mut follower_converter_1, + follower_converter_0, + follower_converter_1, follower_private_key, server_public_key ) @@ -744,43 +666,34 @@ mod tests { let (mut leader, mut follower) = create_pair(); - KeyExchange::>::alloc(&mut leader).unwrap(); - KeyExchange::>::alloc(&mut follower).unwrap(); + leader.alloc().unwrap(); + follower.alloc().unwrap(); - leader.set_private_key(leader_private_key.clone()); - follower.set_private_key(follower_private_key.clone()); + leader.private_key = leader_private_key.clone(); + follower.private_key = follower_private_key.clone(); leader.setup(&mut gen).unwrap(); follower.setup(&mut ev).unwrap(); - tokio::try_join!( + tokio::join!( async { leader.flush(&mut ctx_a).await.unwrap(); - let _client_public_key = - KeyExchange::>::client_key(&leader).unwrap(); - - KeyExchange::>::set_server_key( - &mut leader, - server_public_key, - ) - .unwrap(); - assert_eq!( - KeyExchange::>::server_key(&leader).unwrap(), - server_public_key - ); + let _client_public_key = leader.client_key().unwrap(); + leader.set_server_key(server_public_key).unwrap(); leader.flush(&mut ctx_a).await.unwrap(); - Ok(()) }, - follower.flush(&mut ctx_b) - ) - .unwrap(); + async { + follower.flush(&mut ctx_b).await.unwrap(); + follower.flush(&mut ctx_b).await.unwrap(); + } + ); // Now manipulate pms leader.set_pms_0(P256::one()); follower.set_pms_0(P256::one()); - let eq_check_leader = leader.compute_pms(&mut gen).unwrap(); - let eq_check_follower = follower.compute_pms(&mut ev).unwrap(); + leader.compute_pms(&mut gen).unwrap(); + follower.compute_pms(&mut ev).unwrap(); let (leader_res, follower_res) = tokio::join!( async { @@ -790,7 +703,7 @@ mod tests { .await .map_err(KeyExchangeError::vm) .unwrap(); - eq_check_leader.check().await + leader.flush(&mut ctx_a).await }, async { ev.flush(&mut ctx_b).await.unwrap(); @@ -799,7 +712,7 @@ mod tests { .await .map_err(KeyExchangeError::vm) .unwrap(); - eq_check_follower.check().await + follower.flush(&mut ctx_b).await } ); diff --git a/crates/components/key-exchange/src/lib.rs b/crates/components/key-exchange/src/lib.rs index 8bf8092c1..e1c824db1 100644 --- a/crates/components/key-exchange/src/lib.rs +++ b/crates/components/key-exchange/src/lib.rs @@ -28,8 +28,11 @@ pub use config::{ pub use error::KeyExchangeError; pub use exchange::MpcKeyExchange; -use mpz_core::bitvec::BitVec; -use mpz_memory_core::{binary::U8, Array, DecodeFutureTyped}; +use mpz_memory_core::{ + binary::{Binary, U8}, + Array, Memory, View, +}; +use mpz_vm_core::Vm; use p256::PublicKey; /// Pre-master secret. @@ -48,28 +51,8 @@ impl Pms { } } -/// Checks that both parties behaved honestly. -#[must_use] -#[derive(Debug)] -pub struct EqualityCheck(DecodeFutureTyped); - -impl EqualityCheck { - /// Checks that the PMS computation succeeded and that both parties agree on the PMS value. - /// - /// This MUST be called to ensure that no party cheated. - pub async fn check(self) -> Result<(), KeyExchangeError> { - let eq = self.0.await.map_err(KeyExchangeError::vm)?; - - // Eq should be all zeros if pms_1 == pms_2. - if eq != [0u8; 32] { - return Err(KeyExchangeError::share_conversion("PMS values not equal")); - } - Ok(()) - } -} - /// A trait for the 3-party key exchange protocol. -pub trait KeyExchange { +pub trait KeyExchange { /// Allocate necessary computational resources. fn alloc(&mut self) -> Result<(), KeyExchangeError>; @@ -87,11 +70,15 @@ pub trait KeyExchange { fn client_key(&self) -> Result; /// Performs any necessary one-time setup, returning a reference to the PMS. - fn setup(&mut self, vm: &mut V) -> Result; + fn setup(&mut self, vm: &mut V) -> Result + where + V: Vm + Memory + View; /// Computes the PMS, and returns an equality check. /// - /// The equality check makes sure that both parties arrived at the same result. This MUST be - /// called to prevent malicious behavior! - fn compute_pms(&mut self, vm: &mut V) -> Result; + /// The equality check makes sure that both parties arrived at the same + /// result. This MUST be called to prevent malicious behavior! + fn compute_pms(&mut self, vm: &mut V) -> Result<(), KeyExchangeError> + where + V: Vm + Memory + View; } diff --git a/crates/components/key-exchange/src/mock.rs b/crates/components/key-exchange/src/mock.rs index 68693ff18..f525f7056 100644 --- a/crates/components/key-exchange/src/mock.rs +++ b/crates/components/key-exchange/src/mock.rs @@ -55,7 +55,7 @@ mod tests { fn test_mock_is_ke() { let (leader, follower) = create_mock_key_exchange_pair(); - fn is_key_exchange, Ctx, V>(_: T) {} + fn is_key_exchange(_: T) {} is_key_exchange::< MpcKeyExchange, IdealShareConvertReceiver>, From 3b8760fc7ade650ca8323b6414613f9f3b0c8a6d Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 21 Jan 2025 15:59:13 -0800 Subject: [PATCH 4/5] delete old msg module --- crates/components/key-exchange/src/msg.rs | 47 ----------------------- 1 file changed, 47 deletions(-) delete mode 100644 crates/components/key-exchange/src/msg.rs diff --git a/crates/components/key-exchange/src/msg.rs b/crates/components/key-exchange/src/msg.rs deleted file mode 100644 index ddb1f2469..000000000 --- a/crates/components/key-exchange/src/msg.rs +++ /dev/null @@ -1,47 +0,0 @@ -//! This module contains the message types exchanged between the prover and the TLS verifier. - -use std::fmt::{self, Display, Formatter}; - -use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey as P256PublicKey}; -use serde::{Deserialize, Serialize}; - -/// A type for messages exchanged between the prover and the TLS verifier during the key exchange -/// protocol. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum KeyExchangeMessage { - FollowerPublicKey(PublicKey), - ServerPublicKey(PublicKey), -} - -/// A wrapper for a serialized public key. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PublicKey { - /// The sec1 serialized public key. - pub key: Vec, -} - -/// An error that can occur during parsing of a public key. -#[derive(Debug, thiserror::Error)] -pub struct KeyParseError(#[from] p256::elliptic_curve::Error); - -impl Display for KeyParseError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Unable to parse public key: {}", self.0) - } -} - -impl From for PublicKey { - fn from(value: P256PublicKey) -> Self { - let key = value.to_encoded_point(false).as_bytes().to_vec(); - PublicKey { key } - } -} - -impl TryFrom for P256PublicKey { - type Error = KeyParseError; - - fn try_from(value: PublicKey) -> Result { - P256PublicKey::from_sec1_bytes(&value.key).map_err(Into::into) - } -} From c9f806f644718625320513ef598815f13474abe6 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:01:11 -0800 Subject: [PATCH 5/5] clean up error --- crates/components/key-exchange/src/error.rs | 31 ++++++--------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/crates/components/key-exchange/src/error.rs b/crates/components/key-exchange/src/error.rs index a91a213de..24e64dbc3 100644 --- a/crates/components/key-exchange/src/error.rs +++ b/crates/components/key-exchange/src/error.rs @@ -1,4 +1,4 @@ -use std::{error::Error, fmt::Display}; +use std::error::Error; /// MPC-TLS protocol error. #[derive(Debug, thiserror::Error)] @@ -6,37 +6,24 @@ use std::{error::Error, fmt::Display}; pub struct KeyExchangeError(#[from] ErrorRepr); #[derive(Debug, thiserror::Error)] +#[error("key exchange error: {0}")] pub(crate) enum ErrorRepr { - /// An unexpected state was encountered + #[error("state error: {0}")] State(Box), - /// Context error. + #[error("context error: {0}")] Ctx(Box), - /// IO related error + #[error("io error: {0}")] Io(std::io::Error), - /// Virtual machine error + #[error("vm error: {0}")] Vm(Box), - /// Share conversion error + #[error("share conversion error: {0}")] ShareConversion(Box), - /// Role error + #[error("role error: {0}")] Role(Box), - /// Key error + #[error("key error: {0}")] Key(Box), } -impl Display for ErrorRepr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ErrorRepr::State(error) => write!(f, "{error}"), - ErrorRepr::Ctx(error) => write!(f, "{error}"), - ErrorRepr::Io(error) => write!(f, "{error}"), - ErrorRepr::Vm(error) => write!(f, "{error}"), - ErrorRepr::ShareConversion(error) => write!(f, "{error}"), - ErrorRepr::Role(error) => write!(f, "{error}"), - ErrorRepr::Key(error) => write!(f, "{error}"), - } - } -} - impl KeyExchangeError { pub(crate) fn state(err: E) -> KeyExchangeError where