Skip to content

Commit

Permalink
Updated coordinator logging
Browse files Browse the repository at this point in the history
  • Loading branch information
0xKitsune committed Feb 6, 2024
1 parent 2a3252e commit 8f9b1fd
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 47 deletions.
152 changes: 107 additions & 45 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ pub struct Coordinator {

impl Coordinator {
pub async fn new(config: CoordinatorConfig) -> eyre::Result<Self> {
tracing::info!("Initializing coordinator");

let database = Arc::new(CoordinatorDb::new(&config.db).await?);

tracing::info!("Fetching masks from database");
let masks = database.fetch_masks(0).await?;
let masks = Arc::new(Mutex::new(masks));

tracing::info!("Initializing SQS client");
let sqs_client = Arc::new(sqs_client_from_config(&config.aws).await?);

Ok(Self {
Expand All @@ -57,12 +61,14 @@ impl Coordinator {
}

pub async fn spawn(self: Arc<Self>) -> eyre::Result<()> {
tracing::info!("Starting coordinator");

tracing::info!("Spawning coordinator");
let mut tasks = FuturesUnordered::new();

// TODO: Error handling
tracing::info!("Spawning uniqueness check");
tasks.push(tokio::spawn(self.clone().handle_uniqueness_check()));

tracing::info!("Spawning db sync");
tasks.push(tokio::spawn(self.clone().handle_db_sync()));

while let Some(result) = tasks.next().await {
Expand All @@ -81,30 +87,49 @@ impl Coordinator {
&self.config.queues.shares_queue_url,
)
.await?;
//TODO: Add metric for number of messages in queue

for message in messages {
let receipt_handle = message
.receipt_handle
.context("Missing receipt handle in message")?;

let body = message.body.context("Missing message body")?;

let template: Template = serde_json::from_str(&body)
.context("Failed to parse message")?;

tracing::info!(?template, "Query received");
//TODO: Add the id comm for better observability
tracing::info!(?receipt_handle, "Processing message");

// Sync all new masks that have been added to the database
self.sync_masks().await?;

//TODO: Add the id comm
tracing::info!(
?receipt_handle,
"Sending query to participants"
);
let streams =
self.send_query_to_participants(&template).await?;

let mut handles = vec![];
let mut handles = Vec::with_capacity(2);

//TODO: Add the id comm
tracing::info!(?receipt_handle, "Computing denominators");
let (denominator_rx, denominator_handle) =
self.compute_denominators(template.mask);

handles.push(denominator_handle);

//TODO: Add the id comm
tracing::info!(
?receipt_handle,
"Processing participant shares"
);


//KIT:TODO:
let (batch_process_shares_rx, batch_process_shares_handle) =
self.batch_process_participant_shares(
denominator_rx,
Expand All @@ -113,16 +138,22 @@ impl Coordinator {

handles.push(batch_process_shares_handle);

//TODO: Add the id comm
tracing::info!(?receipt_handle, "Processing results");

//KIT:TODO:
let distance_results =
self.process_results(batch_process_shares_rx).await?;

tracing::info!(?receipt_handle, "Enqueuing results");
sqs_enqueue(
&self.sqs_client,
&self.config.queues.distances_queue_url,
&distance_results,
)
.await?;

tracing::info!(?receipt_handle, "Deleting message from queue");
sqs_delete_message(
&self.sqs_client,
&self.config.queues.shares_queue_url,
Expand All @@ -138,28 +169,41 @@ impl Coordinator {
}
}

#[tracing::instrument(skip(self, query))]
pub async fn send_query_to_participants(
&self,
query: &Template,
) -> eyre::Result<Vec<BufReader<TcpStream>>> {
// Write each share to the corresponding participant
let streams =
future::try_join_all(self.participants.iter().enumerate().map(
|(i, participant_host)| async move {
tracing::info!(
participant = i,
?participant_host,
"Connecting to participant"
);
let mut stream =
TcpStream::connect(participant_host).await?;

let streams = future::try_join_all(self.participants.iter().map(
|participant_host| async move {
let mut stream = TcpStream::connect(participant_host).await?;
tracing::info!(?participant_host, "Connected to participant");
stream.write_all(bytemuck::bytes_of(query)).await?;

stream.write_all(bytemuck::bytes_of(query)).await?;
tracing::info!(?query, "Query sent to participant");
//TODO: Add the id comm for better observability
tracing::info!(
participant = i,
?participant_host,
"Query sent to participant"
);

Ok::<_, eyre::Report>(BufReader::new(stream))
},
))
.await?;
Ok::<_, eyre::Report>(BufReader::new(stream))
},
))
.await?;

Ok(streams)
}

#[tracing::instrument(skip(self))]
pub fn compute_denominators(
&self,
mask: Bits,
Expand All @@ -169,17 +213,22 @@ impl Coordinator {

let denominator_handle = tokio::task::spawn_blocking(move || {
let masks = masks.blocking_lock();

let masks: &[Bits] = bytemuck::cast_slice(&masks);
let engine = MasksEngine::new(&mask);
let total_masks = masks.len();

tracing::info!(?mask, "Computing denominators");
tracing::info!("Processing denominators");

for chunk in masks.chunks(BATCH_SIZE) {
for (i, chunk) in masks.chunks(BATCH_SIZE).enumerate() {
let mut result = vec![[0_u16; 31]; chunk.len()];
engine.batch_process(&mut result, chunk);

tracing::debug!(masks_processed = (i+1) * BATCH_SIZE, ?total_masks, "Denominator batch processed");
sender.blocking_send(result)?;
}

tracing::info!("Denominators processed");

Ok(())
});

Expand All @@ -200,38 +249,47 @@ impl Coordinator {
let batch_worker = tokio::task::spawn(async move {
loop {
// Collect futures of denominator and share batches
let streams_future = future::try_join_all(
streams.iter_mut().map(|stream| async move {
let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
let mut buffer: &mut [u8] =
bytemuck::cast_slice_mut(batch.as_mut_slice());

// We can not use read_exact here as we might get EOF before the
// buffer is full But we should
// still try to fill the entire buffer.
// If nothing else, this guarantees that we read batches at a
// [u16;31] boundary.
while !buffer.is_empty() {
tracing::info!("Reading buffer");

let bytes_read =
stream.read_buf(&mut buffer).await?;

tracing::info!("Buffer read");

if bytes_read == 0 {
let n_incomplete = (buffer.len()
let streams_future =
future::try_join_all(streams.iter_mut().enumerate().map(
|(i, stream)| async move {
let mut batch = vec![[0_u16; 31]; BATCH_SIZE];
let mut buffer: &mut [u8] =
bytemuck::cast_slice_mut(batch.as_mut_slice());

// We can not use read_exact here as we might get EOF before the
// buffer is full But we should
// still try to fill the entire buffer.
// If nothing else, this guarantees that we read batches at a
// [u16;31] boundary.
while !buffer.is_empty() {
let bytes_read =
stream.read_buf(&mut buffer).await?;

tracing::debug!(
participant = i,
bytes_read,
"Bytes read from participant"
);

if bytes_read == 0 {
let n_incomplete = (buffer.len()
+ std::mem::size_of::<[u16; 31]>() //TODO: make this a const
- 1)
/ std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
batch.truncate(batch.len() - n_incomplete);
break;
/ std::mem::size_of::<[u16; 31]>(); //TODO: make this a const
batch.truncate(batch.len() - n_incomplete);
break;
}
}
}

Ok::<_, eyre::Report>(batch)
}),
);
tracing::info!(
participant = i,
batch_size = batch.len(),
"Participant batch received"
);

Ok::<_, eyre::Report>(batch)
},
));

// Wait on all parts concurrently
let (denom, shares) =
Expand All @@ -252,6 +310,9 @@ impl Coordinator {
.iter_mut()
.for_each(|batch| batch.truncate(batch_size));

tracing::info(?batch_size, "Batch processed");


// Send batches
processed_shares_tx.send((denom, shares)).await?;
if batch_size == 0 {
Expand Down Expand Up @@ -359,6 +420,7 @@ impl Coordinator {
Ok(distance_results)
}

#[tracing::instrument(skip(self))]
async fn sync_masks(&self) -> eyre::Result<()> {
let mut masks = self.masks.lock().await;
let next_mask_number = masks.len();
Expand Down
8 changes: 8 additions & 0 deletions src/db/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,28 @@ pub struct CoordinatorDb {

impl CoordinatorDb {
pub async fn new(config: &DbConfig) -> eyre::Result<Self> {
tracing::info!("Connecting to database");

if config.create
&& !sqlx::Postgres::database_exists(&config.url).await?
{
tracing::info!("Creating database");
sqlx::Postgres::create_database(&config.url).await?;
}

let pool = sqlx::Pool::connect(&config.url).await?;

if config.migrate {
tracing::info!("Running migrations");
MIGRATOR.run(&pool).await?;
}

tracing::info!("Connected to database");

Ok(Self { pool })
}

#[tracing::instrument(skip(self))]
pub async fn fetch_masks(&self, id: usize) -> eyre::Result<Vec<Bits>> {
let masks: Vec<(Bits,)> = sqlx::query_as(
r#"
Expand All @@ -43,6 +50,7 @@ impl CoordinatorDb {
Ok(masks.into_iter().map(|(mask,)| mask).collect())
}

#[tracing::instrument(skip(self))]
pub async fn insert_masks(
&self,
masks: &[(u64, Bits)],
Expand Down
21 changes: 19 additions & 2 deletions src/utils/aws.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Debug;

use aws_config::Region;
use aws_sdk_sqs::types::Message;
use eyre::Context;
Expand Down Expand Up @@ -26,6 +28,7 @@ pub async fn sqs_client_from_config(
Ok(aws_client)
}

#[tracing::instrument(skip(client, queue_url))]
pub async fn sqs_dequeue(
client: &aws_sdk_sqs::Client,
queue_url: &str,
Expand All @@ -41,16 +44,24 @@ pub async fn sqs_dequeue(
return Ok(vec![]);
};

let message_receipts = messages
.iter()
.map(|message| message.receipt_handle.clone())
.collect::<Vec<Option<String>>>();

tracing::info!(?message_receipts, "Dequeued messages");

Ok(messages)
}

#[tracing::instrument(skip(client, queue_url, message))]
pub async fn sqs_enqueue<T>(
client: &aws_sdk_sqs::Client,
queue_url: &str,
message: T,
) -> eyre::Result<()>
where
T: Serialize,
T: Serialize + Debug,
{
let body = serde_json::to_string(&message)
.wrap_err("Failed to serialize message")?;
Expand All @@ -62,6 +73,8 @@ where
.send()
.await?;

tracing::info!(?message, "Enqueued message");

Ok(())
}

Expand All @@ -70,12 +83,16 @@ pub async fn sqs_delete_message(
queue_url: impl Into<String>,
receipt_handle: impl Into<String>,
) -> eyre::Result<()> {
let receipt_handle = receipt_handle.into();

client
.delete_message()
.queue_url(queue_url)
.receipt_handle(receipt_handle)
.receipt_handle(&receipt_handle)
.send()
.await?;

tracing::info!(?receipt_handle, "Deleted message from queue");

Ok(())
}

0 comments on commit 8f9b1fd

Please sign in to comment.