diff --git a/Cargo.toml b/Cargo.toml index 126cf1897..721b4919c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ tokio = { version = "1", default-features = false, features = [ "fs", "macros", "signal", + "sync", "rt-multi-thread", "rt", "process", diff --git a/boost_manager/src/main.rs b/boost_manager/src/main.rs index 77f4b4758..adb022086 100644 --- a/boost_manager/src/main.rs +++ b/boost_manager/src/main.rs @@ -139,6 +139,7 @@ impl Server { .add_task(watcher) .add_task(updater) .add_task(purger) + .build() .start() .await } diff --git a/ingest/src/server_iot.rs b/ingest/src/server_iot.rs index fc2f73593..4145944ce 100644 --- a/ingest/src/server_iot.rs +++ b/ingest/src/server_iot.rs @@ -393,6 +393,7 @@ pub async fn grpc_server(settings: &Settings) -> Result<()> { .add_task(beacon_report_sink_server) .add_task(witness_report_sink_server) .add_task(grpc_server) + .build() .start() .await } diff --git a/ingest/src/server_mobile.rs b/ingest/src/server_mobile.rs index a282fdd62..fecd6e2e3 100644 --- a/ingest/src/server_mobile.rs +++ b/ingest/src/server_mobile.rs @@ -454,6 +454,7 @@ pub async fn grpc_server(settings: &Settings) -> Result<()> { .add_task(invalidated_radio_threshold_report_sink_server) .add_task(coverage_object_report_sink_server) .add_task(grpc_server) + .build() .start() .await } diff --git a/ingest/tests/iot_ingest.rs b/ingest/tests/iot_ingest.rs index ab39472c6..5d8d700a5 100644 --- a/ingest/tests/iot_ingest.rs +++ b/ingest/tests/iot_ingest.rs @@ -28,7 +28,11 @@ async fn initialize_session_and_send_beacon_and_witness() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -77,7 +81,11 @@ async fn stream_stops_after_incorrectly_signed_init_request() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -111,7 +119,11 @@ async fn stream_stops_after_incorrectly_signed_beacon() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -148,7 +160,11 @@ async fn stream_stops_after_incorrect_beacon_pubkey() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -188,7 +204,11 @@ async fn stream_stops_after_incorrectly_signed_witness() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -225,7 +245,11 @@ async fn stream_stops_after_incorrect_witness_pubkey() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -265,7 +289,11 @@ async fn stream_stop_if_client_attempts_to_initiliaze_2nd_session() { .run_until(async move { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, None, None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let pub_key = generate_keypair(); @@ -316,7 +344,11 @@ async fn stream_stops_if_init_not_sent_within_timeout() { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, Some(500), None); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let mut client = connect_and_stream(port).await; @@ -338,7 +370,11 @@ async fn stream_stops_on_session_timeout() { tokio::task::spawn_local(async move { let server = create_test_server(port, beacon_client, witness_client, Some(500), Some(900)); - TaskManager::builder().add_task(server).start().await + TaskManager::builder() + .add_task(server) + .build() + .start() + .await }); let mut client = connect_and_stream(port).await; diff --git a/iot_config/src/main.rs b/iot_config/src/main.rs index 18b0b3870..94d5a0966 100644 --- a/iot_config/src/main.rs +++ b/iot_config/src/main.rs @@ -123,6 +123,7 @@ impl Daemon { TaskManager::builder() .add_task(grpc_server) .add_task(db_cleaner) + .build() .start() .await } diff --git a/iot_packet_verifier/src/daemon.rs b/iot_packet_verifier/src/daemon.rs index 9be5be791..b26d22b3e 100644 --- a/iot_packet_verifier/src/daemon.rs +++ b/iot_packet_verifier/src/daemon.rs @@ -207,6 +207,7 @@ impl Cmd { .add_task(verifier_daemon) .add_task(burner) .add_task(report_files_server) + .build() .start() .await } diff --git a/iot_verifier/src/main.rs b/iot_verifier/src/main.rs index 71f978206..3fd50ab9a 100644 --- a/iot_verifier/src/main.rs +++ b/iot_verifier/src/main.rs @@ -318,6 +318,7 @@ impl Server { .add_task(pk_loader_server) .add_task(entropy_loader_server) .add_task(rewarder) + .build() .start() .await } diff --git a/mobile_config/src/main.rs b/mobile_config/src/main.rs index a904e08f7..bfaf9a9b5 100644 --- a/mobile_config/src/main.rs +++ b/mobile_config/src/main.rs @@ -107,7 +107,11 @@ impl Daemon { hex_boosting_svc, }; - TaskManager::builder().add_task(grpc_server).start().await + TaskManager::builder() + .add_task(grpc_server) + .build() + .start() + .await } } diff --git a/mobile_packet_verifier/src/daemon.rs b/mobile_packet_verifier/src/daemon.rs index 99651cca9..34f35d05d 100644 --- a/mobile_packet_verifier/src/daemon.rs +++ b/mobile_packet_verifier/src/daemon.rs @@ -186,6 +186,7 @@ impl Cmd { .add_task(reports_server) .add_task(event_id_purger) .add_task(daemon) + .build() .start() .await } diff --git a/mobile_verifier/src/cli/server.rs b/mobile_verifier/src/cli/server.rs index 6e88bbdcb..8c9a9c00a 100644 --- a/mobile_verifier/src/cli/server.rs +++ b/mobile_verifier/src/cli/server.rs @@ -389,6 +389,7 @@ impl Cmd { .add_task(radio_threshold_ingest_server) .add_task(invalidated_radio_threshold_ingest_server) .add_task(data_session_ingestor) + .build() .start() .await } diff --git a/price/src/main.rs b/price/src/main.rs index c20fe57c0..3803a62fa 100644 --- a/price/src/main.rs +++ b/price/src/main.rs @@ -110,6 +110,7 @@ impl Server { .add_task(hnt_price_generator) .add_task(mobile_price_generator) .add_task(iot_price_generator) + .build() .start() .await } diff --git a/task_manager/src/lib.rs b/task_manager/src/lib.rs index 0f3de24c2..b7d76194a 100644 --- a/task_manager/src/lib.rs +++ b/task_manager/src/lib.rs @@ -3,7 +3,7 @@ mod select_all; use std::pin::pin; use crate::select_all::select_all; -use futures::{future::LocalBoxFuture, Future, StreamExt}; +use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt}; use tokio::signal; pub trait ManagedTask { @@ -17,16 +17,25 @@ pub struct TaskManager { tasks: Vec>, } +impl ManagedTask for TaskManager { + fn start_task( + self: Box, + shutdown: triggered::Listener, + ) -> LocalBoxFuture<'static, anyhow::Result<()>> { + Box::pin(self.do_start(Box::pin(shutdown))) + } +} + pub struct TaskManagerBuilder { tasks: Vec>, } -pub struct StopableLocalFuture { +struct StoppableLocalFuture { shutdown_trigger: triggered::Trigger, future: LocalBoxFuture<'static, anyhow::Result<()>>, } -impl Future for StopableLocalFuture { +impl Future for StoppableLocalFuture { type Output = anyhow::Result<()>; fn poll( @@ -71,20 +80,25 @@ impl TaskManager { pub async fn start(self) -> anyhow::Result<()> { let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?; + let shutdown = Box::pin( + futures::future::select( + Box::pin(async move { sigterm.recv().await }), + Box::pin(signal::ctrl_c()), + ) + .map(|_| ()), + ); + self.do_start(shutdown).await + } - let shutdown_triggers = create_triggers(self.tasks.len()); - - let mut futures = start_futures(shutdown_triggers.clone(), self.tasks); - - let mut shutdown = - futures::future::select(Box::pin(sigterm.recv()), Box::pin(signal::ctrl_c())); + async fn do_start(self, mut shutdown: LocalBoxFuture<'static, ()>) -> anyhow::Result<()> { + let mut futures = start_futures(self.tasks); loop { if futures.is_empty() { break; } - let mut select = select_all(futures.into_iter()); + let mut select = select_all(futures); tokio::select! { _ = &mut shutdown => { @@ -112,30 +126,25 @@ impl TaskManagerBuilder { self } - pub fn start(self) -> impl Future> { - let manager = TaskManager { tasks: self.tasks }; - manager.start() + pub fn build(self) -> TaskManager { + TaskManager { tasks: self.tasks } } } -fn start_futures( - shutdown_triggers: Vec<(triggered::Trigger, triggered::Listener)>, - tasks: Vec>, -) -> Vec { - shutdown_triggers +fn start_futures(tasks: Vec>) -> Vec { + tasks .into_iter() - .zip(tasks) - .map( - |((shutdown_trigger, shutdown_listener), task)| StopableLocalFuture { - shutdown_trigger, - future: task.start_task(shutdown_listener), - }, - ) + .map(|task| { + let (trigger, listener) = triggered::trigger(); + StoppableLocalFuture { + shutdown_trigger: trigger, + future: task.start_task(listener), + } + }) .collect() } -#[allow(clippy::manual_try_fold)] -async fn stop_all(futures: Vec) -> anyhow::Result<()> { +async fn stop_all(futures: Vec) -> anyhow::Result<()> { #[allow(clippy::manual_try_fold)] futures::stream::iter(futures.into_iter().rev()) .then(|local| async move { @@ -148,14 +157,6 @@ async fn stop_all(futures: Vec) -> anyhow::Result<()> { .collect() } -fn create_triggers(n: usize) -> Vec<(triggered::Trigger, triggered::Listener)> { - (0..n).fold(Vec::new(), |mut vec, _| { - let (shutdown_trigger, shutdown_listener) = triggered::trigger(); - vec.push((shutdown_trigger, shutdown_listener)); - vec - }) -} - #[cfg(test)] mod tests { use super::*; @@ -164,10 +165,10 @@ mod tests { use tokio::sync::mpsc; struct TestTask { - id: u64, + name: &'static str, delay: u64, result: anyhow::Result<()>, - sender: mpsc::Sender, + sender: mpsc::Sender<&'static str>, } impl ManagedTask for TestTask { @@ -180,7 +181,7 @@ mod tests { _ = shutdown_listener.clone() => (), _ = tokio::time::sleep(std::time::Duration::from_millis(self.delay)) => (), } - self.sender.send(self.id).await.expect("unable to send"); + self.sender.send(self.name).await.expect("unable to send"); self.result }); @@ -198,22 +199,23 @@ mod tests { let result = TaskManager::builder() .add_task(TestTask { - id: 1, + name: "1", delay: 50, result: Ok(()), sender: sender.clone(), }) .add_task(TestTask { - id: 2, + name: "2", delay: 100, result: Ok(()), sender: sender.clone(), }) + .build() .start() .await; - assert_eq!(Some(1), receiver.recv().await); - assert_eq!(Some(2), receiver.recv().await); + assert_eq!(Some("1"), receiver.recv().await); + assert_eq!(Some("2"), receiver.recv().await); assert!(result.is_ok()); } @@ -223,29 +225,30 @@ mod tests { let result = TaskManager::builder() .add_task(TestTask { - id: 1, + name: "1", delay: 1000, result: Ok(()), sender: sender.clone(), }) .add_task(TestTask { - id: 2, + name: "2", delay: 50, result: Err(anyhow!("error")), sender: sender.clone(), }) .add_task(TestTask { - id: 3, + name: "3", delay: 1000, result: Ok(()), sender: sender.clone(), }) + .build() .start() .await; - assert_eq!(Some(2), receiver.recv().await); - assert_eq!(Some(3), receiver.recv().await); - assert_eq!(Some(1), receiver.recv().await); + assert_eq!(Some("2"), receiver.recv().await); + assert_eq!(Some("3"), receiver.recv().await); + assert_eq!(Some("1"), receiver.recv().await); assert_eq!("error", result.unwrap_err().to_string()); } @@ -255,29 +258,106 @@ mod tests { let result = TaskManager::builder() .add_task(TestTask { - id: 1, + name: "1", delay: 1000, result: Ok(()), sender: sender.clone(), }) .add_task(TestTask { - id: 2, + name: "2", delay: 50, result: Err(anyhow!("error")), sender: sender.clone(), }) .add_task(TestTask { - id: 3, + name: "3", delay: 200, result: Err(anyhow!("second")), sender: sender.clone(), }) + .build() .start() .await; - assert_eq!(Some(2), receiver.recv().await); - assert_eq!(Some(3), receiver.recv().await); - assert_eq!(Some(1), receiver.recv().await); + assert_eq!(Some("2"), receiver.recv().await); + assert_eq!(Some("3"), receiver.recv().await); + assert_eq!(Some("1"), receiver.recv().await); assert_eq!("error", result.unwrap_err().to_string()); } + + #[tokio::test] + async fn nested_tasks_will_stop_parent_then_move_up() { + let (sender, mut receiver) = mpsc::channel(10); + + let result = TaskManager::builder() + .add_task(TestTask { + name: "task-1", + delay: 500, + result: Ok(()), + sender: sender.clone(), + }) + .add_task( + TaskManager::builder() + .add_task(TestTask { + name: "task-2-1", + delay: 500, + result: Ok(()), + sender: sender.clone(), + }) + .add_task(TestTask { + name: "task-2-2", + delay: 100, + result: Err(anyhow!("error")), + sender: sender.clone(), + }) + .add_task(TestTask { + name: "task-2-3", + delay: 500, + result: Ok(()), + sender: sender.clone(), + }) + .add_task(TestTask { + name: "task-2", + delay: 500, + result: Ok(()), + sender: sender.clone(), + }) + .build(), + ) + .add_task( + TaskManager::builder() + .add_task(TestTask { + name: "task-3-1", + delay: 1000, + result: Ok(()), + sender: sender.clone(), + }) + .add_task(TestTask { + name: "task-3-2", + delay: 1000, + result: Ok(()), + sender: sender.clone(), + }) + .add_task(TestTask { + name: "task-3", + delay: 1000, + result: Ok(()), + sender: sender.clone(), + }) + .build(), + ) + .build() + .start() + .await; + + assert_eq!(Some("task-2-2"), receiver.recv().await); + assert_eq!(Some("task-2"), receiver.recv().await); + assert_eq!(Some("task-2-3"), receiver.recv().await); + assert_eq!(Some("task-2-1"), receiver.recv().await); + assert_eq!(Some("task-3"), receiver.recv().await); + assert_eq!(Some("task-3-2"), receiver.recv().await); + assert_eq!(Some("task-3-1"), receiver.recv().await); + assert_eq!(Some("task-1"), receiver.recv().await); + assert!(result.is_err()); + } }