Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(alpha.8) - Refactor prf crate #684

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions crates/components/hmac-sha256/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@ mock = []
[dependencies]
tlsn-hmac-sha256-circuits = { workspace = true }

mpz-garble = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-common = { workspace = true }
mpz-common = { workspace = true, features = ["cpu"] }

async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }

[dev-dependencies]
criterion = { workspace = true, features = ["async_tokio"] }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }

criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }

[[bench]]
name = "prf"
Expand Down
230 changes: 86 additions & 144 deletions crates/components/hmac-sha256/benches/prf.rs
Original file line number Diff line number Diff line change
@@ -1,188 +1,130 @@
#![allow(clippy::let_underscore_future)]

use criterion::{criterion_group, criterion_main, Criterion};

use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role};
use mpz_common::executor::test_mt_executor;
use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPThread, Memory};
use mpz_ot::ideal::ot::ideal_ot;
use mpz_common::executor::{mt::MTConfig, test_mt_executor};
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
};
use rand::{rngs::StdRng, SeedableRng};

#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("prf");
group.sample_size(10);
let rt = tokio::runtime::Runtime::new().unwrap();

group.bench_function("prf_preprocess", |b| b.to_async(&rt).iter(preprocess));
group.bench_function("prf", |b| b.to_async(&rt).iter(prf));
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

async fn preprocess() {
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);
async fn prf() {
let mut rng = StdRng::seed_from_u64(0);

let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();
let pms = [42u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];

let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_exec.new_thread().await.unwrap(),
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(
leader_exec.new_thread().await.unwrap(),
leader_ot_send_1,
leader_ot_recv_1,
)
let (mut leader_exec, mut follower_exec) = test_mt_executor(128, MTConfig::default());
let mut leader_ctx = leader_exec.new_thread().await.unwrap();
let mut follower_ctx = follower_exec.new_thread().await.unwrap();

let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());

let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);

let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();
leader_vm.assign(leader_pms, pms).unwrap();
leader_vm.commit(leader_pms).unwrap();

let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
follower_vm.mark_public(follower_pms).unwrap();
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();

let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());

let leader_output = leader.setup(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.setup(&mut follower_vm, follower_pms).unwrap();

leader
.set_client_random(&mut leader_vm, Some(client_random))
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();

let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_exec.new_thread().await.unwrap(),
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(
follower_exec.new_thread().await.unwrap(),
follower_ot_send_1,
follower_ot_recv_1,
)
leader
.set_server_random(&mut leader_vm, server_random)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
.unwrap();

let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
let _ = leader_vm
.decode(leader_output.keys.client_write_key)
.unwrap();
let _ = leader_vm
.decode(leader_output.keys.server_write_key)
.unwrap();
let _ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
let _ = leader_vm.decode(leader_output.keys.server_iv).unwrap();

let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
let _ = follower_vm
.decode(follower_output.keys.client_write_key)
.unwrap();
let _ = follower_vm
.decode(follower_output.keys.server_write_key)
.unwrap();
let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap();

futures::join!(
async {
leader.setup(leader_pms).await.unwrap();
leader.set_client_random(Some([0u8; 32])).await.unwrap();
leader.preprocess().await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower.setup(follower_pms).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
}

async fn prf() {
let (mut leader_exec, mut follower_exec) = test_mt_executor(128);

let (leader_ot_send_0, follower_ot_recv_0) = ideal_ot();
let (follower_ot_send_0, leader_ot_recv_0) = ideal_ot();
let (leader_ot_send_1, follower_ot_recv_1) = ideal_ot();
let (follower_ot_send_1, leader_ot_recv_1) = ideal_ot();

let leader_thread_0 = DEAPThread::new(
DEAPRole::Leader,
[0u8; 32],
leader_exec.new_thread().await.unwrap(),
leader_ot_send_0,
leader_ot_recv_0,
);
let leader_thread_1 = leader_thread_0
.new_thread(
leader_exec.new_thread().await.unwrap(),
leader_ot_send_1,
leader_ot_recv_1,
)
.unwrap();
let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];

let follower_thread_0 = DEAPThread::new(
DEAPRole::Follower,
[0u8; 32],
follower_exec.new_thread().await.unwrap(),
follower_ot_send_0,
follower_ot_recv_0,
);
let follower_thread_1 = follower_thread_0
.new_thread(
follower_exec.new_thread().await.unwrap(),
follower_ot_send_1,
follower_ot_recv_1,
)
.unwrap();
leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();

let leader_pms = leader_thread_0.new_public_input::<[u8; 32]>("pms").unwrap();
let follower_pms = follower_thread_0
.new_public_input::<[u8; 32]>("pms")
.unwrap();
follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();

let mut leader = MpcPrf::new(
PrfConfig::builder().role(Role::Leader).build().unwrap(),
leader_thread_0,
leader_thread_1,
);
let mut follower = MpcPrf::new(
PrfConfig::builder().role(Role::Follower).build().unwrap(),
follower_thread_0,
follower_thread_1,
);
let _ = leader_vm.decode(leader_output.cf_vd).unwrap();
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();

let pms = [42u8; 32];
let client_random = [0u8; 32];
let server_random = [1u8; 32];
let cf_hs_hash = [2u8; 32];
let sf_hs_hash = [3u8; 32];
let _ = follower_vm.decode(follower_output.cf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();

futures::join!(
async {
leader.setup(leader_pms.clone()).await.unwrap();
leader.set_client_random(Some(client_random)).await.unwrap();
leader.preprocess().await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower.setup(follower_pms.clone()).await.unwrap();
follower.set_client_random(None).await.unwrap();
follower.preprocess().await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);

leader.thread_mut().assign(&leader_pms, pms).unwrap();
follower.thread_mut().assign(&follower_pms, pms).unwrap();

let (_leader_keys, _follower_keys) = futures::try_join!(
leader.compute_session_keys(server_random),
follower.compute_session_keys(server_random)
)
.unwrap();

let _ = futures::try_join!(
leader.compute_client_finished_vd(cf_hs_hash),
follower.compute_client_finished_vd(cf_hs_hash)
)
.unwrap();

let _ = futures::try_join!(
leader.compute_server_finished_vd(sf_hs_hash),
follower.compute_server_finished_vd(sf_hs_hash)
)
.unwrap();

futures::try_join!(
leader.thread_mut().finalize(),
follower.thread_mut().finalize()
)
.unwrap();
}
26 changes: 6 additions & 20 deletions crates/components/hmac-sha256/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl PrfError {
source: Some(msg.into().into()),
}
}

pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
}
}

#[derive(Debug)]
Expand All @@ -58,26 +62,8 @@ impl fmt::Display for PrfError {
}
}

impl From<mpz_garble::MemoryError> for PrfError {
fn from(error: mpz_garble::MemoryError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

impl From<mpz_garble::LoadError> for PrfError {
fn from(error: mpz_garble::LoadError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

impl From<mpz_garble::ExecutionError> for PrfError {
fn from(error: mpz_garble::ExecutionError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

impl From<mpz_garble::DecodeError> for PrfError {
fn from(error: mpz_garble::DecodeError) -> Self {
impl From<mpz_common::ContextError> for PrfError {
fn from(error: mpz_common::ContextError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}
Loading
Loading