Upgrade libprio to alpha prerelease

divergentdave committed Jan 25, 2025
1 parent f96fe5e commit 535c3fe
Showing 12 changed files with 120 additions and 61 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 1 addition & 3 deletions Cargo.toml
@@ -75,9 +75,7 @@ postgres-types = "0.2.8"
pretty_assertions = "1.4.1"
# Disable default features so that individual workspace crates can choose to
# re-enable them
# TODO(#3436): switch to a released version of libprio, once there is a released version implementing VDAF-13
# prio = { version = "0.16.7", default-features = false, features = ["experimental"] }
prio = { git = "https://github.com/divviup/libprio-rs", rev = "3c1aeb30c661d373566749a81589fc0a4045f89a", default-features = false, features = ["experimental"] }
prio = { version = "0.17.0-alpha.0", default-features = false, features = ["experimental"] }
prometheus = "0.13.4"
querystring = "1.1.0"
quickcheck = { version = "1.0.3", default-features = false }
2 changes: 1 addition & 1 deletion aggregator/src/aggregator.rs
@@ -829,7 +829,7 @@ impl<C: Clock> TaskAggregator<C> {
}

VdafInstance::Prio3Sum { max_measurement } => {
let vdaf = Prio3::new_sum(2, u128::from(*max_measurement))?;
let vdaf = Prio3::new_sum(2, *max_measurement)?;
let verify_key = task.vdaf_verify_key()?;
VdafOps::Prio3Sum(Arc::new(vdaf), verify_key)
}
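The hunks above reflect the changed Prio3Sum constructor in the libprio prerelease: the maximum measurement is now passed in its native integer width instead of being widened to u128. A minimal sketch of the new call shape, assuming prio 0.17.0-alpha.0 and a u64 bound (the exact parameter type is an assumption here):

use prio::vdaf::{
    prio3::{Prio3, Prio3Sum},
    VdafError,
};

// Construct the sum VDAF for two aggregators; the bound is passed directly,
// with no u128::from() widening as under the previous git revision.
fn build_sum_vdaf(max_measurement: u64) -> Result<Prio3Sum, VdafError> {
    Prio3::new_sum(2, max_measurement)
}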
10 changes: 7 additions & 3 deletions aggregator/src/aggregator/aggregation_job_creator.rs
@@ -312,7 +312,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
}

(task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { max_measurement }) => {
let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?);
let vdaf = Arc::new(Prio3::new_sum(2, *max_measurement)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<VERIFY_KEY_LENGTH, Prio3Sum>(task, vdaf)
.await
}
@@ -402,7 +402,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
VdafInstance::Prio3Count,
) => {
let vdaf: Arc<
Prio3<prio::flp::types::Count<Field64>, vdaf::xof::XofTurboShake128, 16>,
Prio3<
prio::flp::types::Count<Field64>,
vdaf::xof::XofTurboShake128,
VERIFY_KEY_LENGTH,
>,
> = Arc::new(Prio3::new_count(2)?);
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_leader_selected_task_no_param::<
@@ -417,7 +421,7 @@
},
VdafInstance::Prio3Sum { max_measurement },
) => {
let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?);
let vdaf = Arc::new(Prio3::new_sum(2, *max_measurement)?);
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_leader_selected_task_no_param::<
VERIFY_KEY_LENGTH,
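Here the hardcoded seed size 16 in the Prio3Count type annotation is replaced by the shared VERIFY_KEY_LENGTH constant, which moves to 32 bytes elsewhere in this commit. A minimal sketch of the pattern, with the constant inlined so the snippet stands alone (the alias name is illustrative):

use prio::{
    field::Field64,
    flp::types::Count,
    vdaf::{prio3::Prio3, xof::XofTurboShake128, VdafError},
};

/// Verify key length for Prio3 instantiations using XofTurboShake128
/// (32 bytes after this upgrade, per core/src/vdaf.rs below).
const VERIFY_KEY_LENGTH: usize = 32;

// Spelling the seed size through the constant keeps every type annotation in
// sync if the XOF's seed size changes again.
type CountVdaf = Prio3<Count<Field64>, XofTurboShake128, VERIFY_KEY_LENGTH>;

fn build_count_vdaf() -> Result<CountVdaf, VdafError> {
    Prio3::new_count(2)
}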
51 changes: 43 additions & 8 deletions aggregator_api/src/tests.rs
@@ -26,7 +26,7 @@ use janus_core::{
hpke::HpkeKeypair,
test_util::install_test_trace_subscriber,
time::MockClock,
vdaf::{vdaf_dp_strategies, VdafInstance},
vdaf::{vdaf_dp_strategies, VdafInstance, VERIFY_KEY_LENGTH},
};
use janus_messages::{
Duration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, Role,
@@ -183,7 +183,12 @@ async fn post_task_bad_role() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, _) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);
let aggregator_auth_token = AuthenticationToken::DapAuth(random());

let req = PostTaskReq {
@@ -217,7 +222,12 @@ async fn post_task_unauthorized() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, _) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);
let aggregator_auth_token = AuthenticationToken::DapAuth(random());

let req = PostTaskReq {
@@ -252,7 +262,12 @@ async fn post_task_helper_no_optional_fields() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, ds) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);

// Verify: posting a task creates a new task which matches the request.
let req = PostTaskReq {
@@ -332,7 +347,12 @@ async fn post_task_helper_with_aggregator_auth_token() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, _) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);
let aggregator_auth_token = AuthenticationToken::DapAuth(random());

// Verify: posting a task with role = helper and an aggregator auth token fails
@@ -368,7 +388,12 @@ async fn post_task_idempotence() {
let (handler, ephemeral_datastore, _) = setup_api_test().await;
let ds = ephemeral_datastore.datastore(MockClock::default()).await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);
let aggregator_auth_token = AuthenticationToken::DapAuth(random());

// Verify: posting a task creates a new task which matches the request.
@@ -442,7 +467,12 @@ async fn post_task_leader_all_optional_fields() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, ds) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);
let aggregator_auth_token = AuthenticationToken::DapAuth(random());
let collector_auth_token_hash = AuthenticationTokenHash::from(&random());
// Verify: posting a task creates a new task which matches the request.
@@ -522,7 +552,12 @@ async fn post_task_leader_no_aggregator_auth_token() {
// Setup: create a datastore & handler.
let (handler, _ephemeral_datastore, _) = setup_api_test().await;

let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect());
let vdaf_verify_key = SecretBytes::new(
thread_rng()
.sample_iter(Standard)
.take(VERIFY_KEY_LENGTH)
.collect(),
);

// Verify: posting a task with role = Leader and no aggregator auth token fails
let req = PostTaskReq {
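These test changes all follow one pattern: the VDAF verify key is sampled to VERIFY_KEY_LENGTH bytes instead of a literal 16. A minimal sketch of that pattern, collecting into a plain Vec<u8> rather than SecretBytes to stay self-contained:

use janus_core::vdaf::VERIFY_KEY_LENGTH;
use rand::{distributions::Standard, thread_rng, Rng};

// Sample a verify key whose length tracks the shared constant (32 bytes after
// this upgrade) rather than a hardcoded 16.
fn random_verify_key() -> Vec<u8> {
    thread_rng()
        .sample_iter(Standard)
        .take(VERIFY_KEY_LENGTH)
        .collect()
}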
2 changes: 1 addition & 1 deletion core/src/dp.rs
@@ -61,7 +61,7 @@ impl AggregatorWithNoise<0, 16, NoDifferentialPrivacy> for dummy::Vdaf {
}

// identity strategy implementations for vdafs from libprio
impl TypeWithNoise<NoDifferentialPrivacy> for prio::flp::types::Sum<Field128> {
impl TypeWithNoise<NoDifferentialPrivacy> for prio::flp::types::Sum<Field64> {
fn add_noise_to_result(
&self,
_dp_strategy: &NoDifferentialPrivacy,
4 changes: 2 additions & 2 deletions core/src/vdaf.rs
@@ -13,7 +13,7 @@ use std::str;

/// The length of the verify key parameter for Prio3 VDAF instantiations using
/// [`XofTurboShake128`][prio::vdaf::xof::XofTurboShake128].
pub const VERIFY_KEY_LENGTH: usize = 16;
pub const VERIFY_KEY_LENGTH: usize = 32;

/// Private use algorithm ID for a customized version of Prio3SumVec. This value was chosen for
/// interoperability with Daphne.
@@ -265,7 +265,7 @@ macro_rules! vdaf_dispatch_impl_base {
}

::janus_core::vdaf::VdafInstance::Prio3Sum { max_measurement } => {
let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *max_measurement as u128)?;
let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *max_measurement)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3Sum;
const $VERIFY_KEY_LEN: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH;
type $DpStrategy = janus_core::dp::NoDifferentialPrivacy;
2 changes: 1 addition & 1 deletion integration_tests/tests/integration/common.rs
@@ -396,7 +396,7 @@ pub async fn submit_measurements_and_verify_aggregate(
.await;
}
VdafInstance::Prio3Sum { max_measurement } => {
let max_measurement = u128::from(*max_measurement);
let max_measurement = *max_measurement;
let vdaf = Prio3::new_sum(2, max_measurement).unwrap();

let measurements: Vec<_> =
82 changes: 52 additions & 30 deletions integration_tests/tests/integration/simulation/bad_client.rs
@@ -20,15 +20,15 @@ use janus_messages::{
};
use prio::{
codec::{Decode, Encode, ParameterizedDecode},
field::{Field128, FieldElement},
flp::{gadgets::ParallelSum, types::Histogram, Type},
field::{random_vector, Field128, FieldElement},
flp::{gadgets::ParallelSum, types::Histogram, Flp},
vdaf::{
prio3::{optimal_chunk_length, Prio3, Prio3Histogram, Prio3InputShare, Prio3PublicShare},
xof::{Seed, Xof, XofTurboShake128},
AggregateShare, Aggregator, Client as _, Collector, PrepareTransition, Vdaf,
},
};
use rand::{distributions::Standard, random, Rng};
use rand::{distributions::Standard, random, thread_rng, Rng};
use std::{net::Ipv4Addr, sync::Arc};
use tokio::net::TcpListener;
use trillium_tokio::Stopper;
@@ -130,7 +130,7 @@ fn shard_encoded_measurement(
task_id: &TaskId,
encoded_measurement: Vec<Field128>,
report_id: ReportId,
) -> (Prio3PublicShare<16>, Vec<Prio3InputShare<Field128, 16>>) {
) -> (Prio3PublicShare<32>, Vec<Prio3InputShare<Field128, 32>>) {
const DST_MEASUREMENT_SHARE: u16 = 1;
const DST_PROOF_SHARE: u16 = 2;
const DST_JOINT_RANDOMNESS: u16 = 3;
@@ -142,6 +142,8 @@

const NUM_PROOFS: u8 = 1;

let mut rng = thread_rng();

let ctx = vdaf_application_context(task_id);

assert_eq!(encoded_measurement.len(), MAX_REPORTS);
@@ -150,11 +152,11 @@
Histogram::new(MAX_REPORTS, chunk_length).unwrap();

// Share measurement.
let helper_measurement_share_seed = Seed::<16>::generate().unwrap();
let helper_share_seed = rng.gen::<Seed<32>>();
let mut helper_measurement_share_rng = XofTurboShake128::seed_stream(
&helper_measurement_share_seed,
&vdaf.domain_separation_tag(DST_MEASUREMENT_SHARE, &ctx),
&[HELPER_AGGREGATOR_ID],
helper_share_seed.as_ref(),
&[&domain_separation_tag(vdaf, DST_MEASUREMENT_SHARE), &ctx],
&[&[HELPER_AGGREGATOR_ID]],
);
let mut expanded_helper_measurement_share: Vec<Field128> =
Vec::with_capacity(circuit.input_len());
@@ -166,10 +168,10 @@
}

// Derive joint randomness.
let helper_joint_rand_blind = Seed::<16>::generate().unwrap();
let helper_joint_rand_blind = rng.gen::<Seed<32>>();
let mut helper_joint_rand_part_xof = XofTurboShake128::init(
helper_joint_rand_blind.as_ref(),
&vdaf.domain_separation_tag(DST_JOINT_RAND_PART, &ctx),
&[&domain_separation_tag(vdaf, DST_JOINT_RAND_PART), &ctx],
);
helper_joint_rand_part_xof.update(&[HELPER_AGGREGATOR_ID]);
helper_joint_rand_part_xof.update(report_id.as_ref());
@@ -178,10 +180,10 @@
}
let helper_joint_rand_seed_part = helper_joint_rand_part_xof.into_seed();

let leader_joint_rand_blind = Seed::<16>::generate().unwrap();
let leader_joint_rand_blind = rng.gen::<Seed<32>>();
let mut leader_joint_rand_part_xof = XofTurboShake128::init(
leader_joint_rand_blind.as_ref(),
&vdaf.domain_separation_tag(DST_JOINT_RAND_PART, &ctx),
&[&domain_separation_tag(vdaf, DST_JOINT_RAND_PART), &ctx],
);
leader_joint_rand_part_xof.update(&[LEADER_AGGREGATOR_ID]);
leader_joint_rand_part_xof.update(report_id.as_ref());
@@ -191,17 +193,17 @@
let leader_joint_rand_seed_part = leader_joint_rand_part_xof.into_seed();

let mut joint_rand_seed_xof = XofTurboShake128::init(
&[0; 16],
&vdaf.domain_separation_tag(DST_JOINT_RAND_SEED, &ctx),
&[0; 32],
&[&domain_separation_tag(vdaf, DST_JOINT_RAND_SEED), &ctx],
);
joint_rand_seed_xof.update(leader_joint_rand_seed_part.as_ref());
joint_rand_seed_xof.update(helper_joint_rand_seed_part.as_ref());
let joint_rand_seed = joint_rand_seed_xof.into_seed();
let mut joint_rand: Vec<Field128> = Vec::with_capacity(circuit.joint_rand_len());
let mut joint_rand_xof = XofTurboShake128::seed_stream(
&joint_rand_seed,
&vdaf.domain_separation_tag(DST_JOINT_RANDOMNESS, &ctx),
&[NUM_PROOFS],
joint_rand_seed.as_ref(),
&[&domain_separation_tag(vdaf, DST_JOINT_RANDOMNESS), &ctx],
&[&[NUM_PROOFS]],
);
for _ in 0..circuit.joint_rand_len() {
joint_rand.push(joint_rand_xof.sample(Standard));
@@ -215,14 +217,13 @@
let mut leader_proof_share = circuit
.prove(&encoded_measurement, &prove_rand, &joint_rand)
.unwrap();
let helper_proof_share_seed = Seed::<16>::generate().unwrap();
let mut helper_proof_share_xof = XofTurboShake128::seed_stream(
&helper_proof_share_seed,
&vdaf.domain_separation_tag(DST_PROOF_SHARE, &ctx),
&[NUM_PROOFS, HELPER_AGGREGATOR_ID],
let mut helper_proofs_share_xof = XofTurboShake128::seed_stream(
helper_share_seed.as_ref(),
&[&domain_separation_tag(vdaf, DST_PROOF_SHARE), &ctx],
&[&[NUM_PROOFS, HELPER_AGGREGATOR_ID]],
);
for leader_elem in leader_proof_share.iter_mut() {
let helper_elem = helper_proof_share_xof.sample(Standard);
let helper_elem = helper_proofs_share_xof.sample(Standard);
*leader_elem -= helper_elem;
}

@@ -250,10 +251,7 @@ fn shard_encoded_measurement(
let leader_input_share =
Prio3InputShare::get_decoded_with_param(&(vdaf, 0), &encoded_leader_input_share).unwrap();
let mut encoded_helper_input_share = Vec::new();
helper_measurement_share_seed
.encode(&mut encoded_helper_input_share)
.unwrap();
helper_proof_share_seed
helper_share_seed
.encode(&mut encoded_helper_input_share)
.unwrap();
helper_joint_rand_blind
@@ -268,11 +266,20 @@
)
}

fn domain_separation_tag(vdaf: &impl Vdaf, usage: u16) -> [u8; 8] {
let mut dst = [0; 8];
dst[0] = 12; // version
dst[1] = 0; // algorithm class
dst[2..6].copy_from_slice(vdaf.algorithm_id().to_be_bytes().as_slice());
dst[6..8].copy_from_slice(usage.to_be_bytes().as_slice());
dst
}
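The new domain_separation_tag helper replaces the removed vdaf.domain_separation_tag(...) calls: the fixed 8-byte prefix (version byte, algorithm-class byte, 4-byte algorithm ID, 2-byte usage constant, all big-endian) is built by hand, and the per-task application context is now passed to the XOF as a separate part. A standalone check of that layout (the algorithm ID below is an illustrative value, not taken from this commit):

fn tag(algorithm_id: u32, usage: u16) -> [u8; 8] {
    let mut dst = [0u8; 8];
    dst[0] = 12; // version byte, matching the helper above
    dst[1] = 0; // algorithm class
    dst[2..6].copy_from_slice(&algorithm_id.to_be_bytes());
    dst[6..8].copy_from_slice(&usage.to_be_bytes());
    dst
}

fn main() {
    // With usage DST_MEASUREMENT_SHARE = 1 (defined earlier in this file):
    assert_eq!(tag(0x0000_0004, 1), [12, 0, 0, 0, 0, 4, 0, 1]);
}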

async fn prepare_report(
http_client: &reqwest::Client,
task: &Task,
public_share: Prio3PublicShare<16>,
input_shares: Vec<Prio3InputShare<Field128, 16>>,
public_share: Prio3PublicShare<32>,
input_shares: Vec<Prio3InputShare<Field128, 32>>,
report_id: ReportId,
report_time: Time,
) -> Result<Report, janus_client::Error> {
@@ -452,12 +459,27 @@ fn shard_encoded_measurement_correct() {

let mut encoded_measurement = Vec::from([Field128::zero(); MAX_REPORTS]);
encoded_measurement[0] = Field128::one();

// Check the circuit output first.
let histogram =
Histogram::<Field128, ParallelSum<_, _>>::new(MAX_REPORTS, chunk_length).unwrap();
let joint_rand = random_vector(histogram.joint_rand_len());
let circuit_output = histogram
.valid(
&mut histogram.gadget(),
&encoded_measurement,
&joint_rand,
1,
)
.unwrap();
assert_eq!(circuit_output, [Field128::zero(); 2]);

let task_id = random();
let report_id = random();
let (public_share, input_shares) =
shard_encoded_measurement(&vdaf, &task_id, encoded_measurement, report_id);

let verify_key: [u8; 16] = random();
let verify_key: [u8; 32] = random();
let ctx = vdaf_application_context(&task_id);
let (leader_prepare_state, leader_prepare_share) = vdaf
.prepare_init(
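Several hunks in this file track the updated XofTurboShake128 interface: seeds are 32 bytes, and the domain-separation tag and binder are each supplied as a slice of byte-string parts rather than as single concatenated strings. A minimal sketch of that calling pattern, mirroring the usage above (treat the exact signatures as assumptions about the 0.17.0-alpha.0 API):

use prio::{
    field::Field128,
    vdaf::xof::{Seed, Xof, XofTurboShake128},
};
use rand::{distributions::Standard, thread_rng, Rng};

fn main() {
    let mut rng = thread_rng();
    // 32-byte seed, sampled via rand as in the diff above.
    let seed = rng.gen::<Seed<32>>();

    // init() takes the dst as parts; update() feeds the binder incrementally.
    let mut xof = XofTurboShake128::init(seed.as_ref(), &[b"dst-part", b"app-ctx"]);
    xof.update(b"binder");
    let _derived_seed = xof.into_seed();

    // seed_stream() takes dst and binder parts together and can be sampled
    // for field elements.
    let mut stream =
        XofTurboShake128::seed_stream(seed.as_ref(), &[b"dst-part", b"app-ctx"], &[b"binder"]);
    let _elem: Field128 = stream.sample(Standard);
}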