diff --git a/arbiter-engine/src/agent.rs b/arbiter-engine/src/agent.rs index 9ad86c7f..e862e9e0 100644 --- a/arbiter-engine/src/agent.rs +++ b/arbiter-engine/src/agent.rs @@ -144,8 +144,9 @@ impl Agent { pub(crate) async fn run(&mut self, instruction: MachineInstruction) { let behavior_engines = self.behavior_engines.take().unwrap(); let behavior_tasks = join_all(behavior_engines.into_iter().map(|mut engine| { + let instruction_clone = instruction.clone(); tokio::spawn(async move { - engine.execute(instruction).await; + engine.execute(instruction_clone).await; engine }) })); @@ -164,10 +165,14 @@ impl StateMachine for Agent { #[tracing::instrument(skip(self), fields(id = self.id))] async fn execute(&mut self, instruction: MachineInstruction) { match instruction { - MachineInstruction::Sync => { + MachineInstruction::Sync(_, _) => { debug!("Agent is syncing."); self.state = State::Syncing; - self.run(instruction).await; + self.run(MachineInstruction::Sync( + self.messager.clone(), + Some(self.client.clone()), + )) + .await; } MachineInstruction::Start => { debug!("Agent is starting up."); diff --git a/arbiter-engine/src/examples/timed_message.rs b/arbiter-engine/src/examples/timed_message.rs index 74d95748..df76c0c8 100644 --- a/arbiter-engine/src/examples/timed_message.rs +++ b/arbiter-engine/src/examples/timed_message.rs @@ -19,23 +19,42 @@ struct TimedMessage { delay: u64, receive_data: String, send_data: String, - messager: Messager, + messager: Option, count: u64, max_count: Option, } +impl TimedMessage { + pub fn new( + delay: u64, + receive_data: String, + send_data: String, + max_count: Option, + ) -> Self { + Self { + delay, + receive_data, + send_data, + messager: None, + count: 0, + max_count, + } + } +} + #[async_trait::async_trait] impl Behavior for TimedMessage { async fn process(&mut self, event: Message) -> Option { trace!("Processing event."); + let messager = self.messager.as_ref().unwrap(); if event.data == self.receive_data { trace!("Event matches message. Sending a new message."); let message = Message { - from: self.messager.id.clone().unwrap(), + from: messager.id.clone().unwrap(), to: To::All, data: self.send_data.clone(), }; - self.messager.send(message).await; + messager.send(message).await; self.count += 1; } if self.count == self.max_count.unwrap_or(u64::MAX) { @@ -48,8 +67,9 @@ impl Behavior for TimedMessage { None } - async fn sync(&mut self) { + async fn sync(&mut self, messager: Messager, _client: Arc) { trace!("Syncing state for `TimedMessage`."); + self.messager = Some(messager); tokio::time::sleep(std::time::Duration::from_secs(self.delay)).await; trace!("Synced state for `TimedMessage`."); } @@ -63,24 +83,15 @@ impl Behavior for TimedMessage { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn echoer() { - // std::env::set_var("RUST_LOG", "trace"); - // tracing_subscriber::fmt::init(); - let mut world = World::new("world"); let agent = Agent::new(AGENT_ID, &world); - let behavior = TimedMessage { - delay: 1, - receive_data: "Hello, world!".to_owned(), - send_data: "Hello, world!".to_owned(), - messager: agent - .messager - .as_ref() - .unwrap() - .join_with_id(Some(AGENT_ID.to_owned())), - count: 0, - max_count: Some(2), - }; + let behavior = TimedMessage::new( + 1, + "Hello, world!".to_owned(), + "Hello, world!".to_owned(), + Some(2), + ); world.add_agent(agent.with_behavior(behavior)); let messager = world.messager.join_with_id(Some("god".to_owned())); @@ -115,37 +126,11 @@ async fn echoer() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn ping_pong() { - // std::env::set_var("RUST_LOG", "trace"); - // tracing_subscriber::fmt::init(); - let mut world = World::new("world"); let agent = Agent::new(AGENT_ID, &world); - let behavior_ping = TimedMessage { - delay: 1, - receive_data: "pong".to_owned(), - send_data: "ping".to_owned(), - messager: agent - .messager - .as_ref() - .unwrap() - .join_with_id(Some(AGENT_ID.to_owned())), - count: 0, - max_count: Some(2), - }; - let behavior_pong = TimedMessage { - delay: 1, - receive_data: "ping".to_owned(), - send_data: "pong".to_owned(), - messager: agent - .messager - .as_ref() - .unwrap() - .join_with_id(Some(AGENT_ID.to_owned())), - count: 0, - max_count: Some(2), - }; - + let behavior_ping = TimedMessage::new(1, "pong".to_owned(), "ping".to_owned(), Some(2)); + let behavior_pong = TimedMessage::new(1, "ping".to_owned(), "pong".to_owned(), Some(2)); world.add_agent( agent .with_behavior(behavior_ping) @@ -185,38 +170,13 @@ async fn ping_pong() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn ping_pong_two_agent() { - // std::env::set_var("RUST_LOG", "trace"); - // tracing_subscriber::fmt::init(); - let mut world = World::new("world"); let agent_ping = Agent::new("agent_ping", &world); - let behavior_ping = TimedMessage { - delay: 1, - receive_data: "pong".to_owned(), - send_data: "ping".to_owned(), - messager: agent_ping - .messager - .as_ref() - .unwrap() - .join_with_id(Some("agent_ping".to_owned())), - count: 0, - max_count: Some(2), - }; + let behavior_ping = TimedMessage::new(1, "pong".to_owned(), "ping".to_owned(), Some(2)); let agent_pong = Agent::new("agent_pong", &world); - let behavior_pong = TimedMessage { - delay: 1, - receive_data: "ping".to_owned(), - send_data: "pong".to_owned(), - messager: agent_pong - .messager - .as_ref() - .unwrap() - .join_with_id(Some("agent_pong".to_owned())), - count: 0, - max_count: Some(2), - }; + let behavior_pong = TimedMessage::new(1, "ping".to_owned(), "pong".to_owned(), Some(2)); world.add_agent(agent_ping.with_behavior(behavior_ping)); world.add_agent(agent_pong.with_behavior(behavior_pong)); diff --git a/arbiter-engine/src/examples/token_minter.rs b/arbiter-engine/src/examples/token_minter.rs index f75f7ef4..230bacae 100644 --- a/arbiter-engine/src/examples/token_minter.rs +++ b/arbiter-engine/src/examples/token_minter.rs @@ -33,10 +33,9 @@ pub struct TokenAdmin { pub tokens: Option>>, - // TODO: We should not have to have a client or a messager put here - // explicitly, they should come from the Agent the behavior is given to. - pub client: Arc, - pub messager: Messager, + pub client: Option>, + + pub messager: Option, count: u64, @@ -44,17 +43,12 @@ pub struct TokenAdmin { } impl TokenAdmin { - pub fn new( - client: Arc, - messager: Messager, - count: u64, - max_count: Option, - ) -> Self { + pub fn new(count: u64, max_count: Option) -> Self { Self { token_data: HashMap::new(), tokens: None, - client, - messager, + client: None, + messager: None, count, max_count, } @@ -99,12 +93,11 @@ pub struct MintRequest { #[async_trait::async_trait] impl Behavior for TokenAdmin { - #[tracing::instrument(skip(self), fields(id = -self.messager.id.as_deref()))] - async fn sync(&mut self) { + #[tracing::instrument(skip(self), fields(id = messager.id.as_deref()))] + async fn sync(&mut self, messager: Messager, client: Arc) { for token_data in self.token_data.values_mut() { let token = ArbiterToken::deploy( - self.client.clone(), + client.clone(), ( token_data.name.clone(), token_data.symbol.clone(), @@ -121,10 +114,11 @@ self.messager.id.as_deref()))] .insert(token_data.name.clone(), token.clone()); debug!("Deployed token: {:?}", token); } + self.messager = Some(messager); + self.client = Some(client); } - #[tracing::instrument(skip(self), fields(id = -self.messager.id.as_deref()))] + #[tracing::instrument(skip(self), fields(id = self.messager.as_ref().unwrap().id.as_deref()))] async fn process(&mut self, event: Message) -> Option { if self.tokens.is_none() { error!( @@ -135,6 +129,7 @@ the token admin before running the simulation." let query: TokenAdminQuery = serde_json::from_str(&event.data).unwrap(); trace!("Got query: {:?}", query); + let messager = self.messager.as_ref().unwrap(); match query { TokenAdminQuery::AddressOf(token_name) => { trace!( @@ -143,11 +138,11 @@ the token admin before running the simulation." ); let token_data = self.token_data.get(&token_name).unwrap(); let message = Message { - from: self.messager.id.clone().unwrap(), + from: messager.id.clone().unwrap(), to: To::Agent(event.from.clone()), // Reply back to sender data: serde_json::to_string(token_data).unwrap(), }; - self.messager.send(message).await; + messager.send(message).await; } TokenAdminQuery::MintRequest(mint_request) => { trace!("Minting tokens: {:?}", mint_request); @@ -186,10 +181,10 @@ pub struct TokenRequester { pub request_to: String, /// Client to have an address to receive token mint to and check balance - pub client: Arc, + pub client: Option>, /// The messaging layer for the token requester. - pub messager: Messager, + pub messager: Option, pub count: u64, @@ -211,8 +206,8 @@ impl TokenRequester { address: None, }, request_to: TOKEN_ADMIN_ID.to_owned(), - client, - messager, + client: None, + messager: None, count, max_count, } @@ -221,23 +216,29 @@ impl TokenRequester { #[async_trait::async_trait] impl Behavior for TokenRequester { - #[tracing::instrument(skip(self), fields(id = -self.messager.id.as_deref()))] + #[tracing::instrument(skip(self), fields(id = messager.id.as_deref()))] + async fn sync(&mut self, messager: Messager, client: Arc) { + self.messager = Some(messager); + self.client = Some(client); + } + + #[tracing::instrument(skip(self), fields(id = self.messager.as_ref().unwrap().id.as_deref()))] async fn startup(&mut self) { + let messager = self.messager.as_ref().unwrap(); trace!("Requesting address of token: {:?}", self.token_data.name); let message = Message { - from: self.messager.id.clone().unwrap(), + from: messager.id.clone().unwrap(), to: To::Agent(self.request_to.clone()), data: serde_json::to_string(&TokenAdminQuery::AddressOf(self.token_data.name.clone())) .unwrap(), }; - self.messager.send(message).await; + messager.send(message).await; } - #[tracing::instrument(skip(self), fields(id = -self.messager.id.as_deref()))] + #[tracing::instrument(skip(self), fields(id = self.messager.as_ref().unwrap().id.as_deref()))] async fn process(&mut self, event: Message) -> Option { if let Ok(token_data) = serde_json::from_str::(&event.data) { + let messager = self.messager.as_ref().unwrap(); trace!( "Got token data: {:?}", @@ -249,16 +250,16 @@ token: {:?}", self.token_data.name ); let message = Message { - from: self.messager.id.clone().unwrap(), + from: messager.id.clone().unwrap(), to: To::Agent(self.request_to.clone()), data: serde_json::to_string(&TokenAdminQuery::MintRequest(MintRequest { token: self.token_data.name.clone(), - mint_to: self.client.address(), + mint_to: self.client.as_ref().unwrap().address(), mint_amount: 1, })) .unwrap(), }; - self.messager.send(message).await; + messager.send(message).await; } Some(MachineHalt) } @@ -266,9 +267,14 @@ token: {:?}", #[async_trait::async_trait] impl Behavior for TokenRequester { - #[tracing::instrument(skip(self), fields(id = -self.messager.id.as_deref()))] + async fn sync(&mut self, messager: Messager, client: Arc) { + self.client = Some(client); + self.messager = Some(messager); + } + + #[tracing::instrument(skip(self), fields(id = self.messager.as_ref().unwrap().id.as_deref()))] async fn process(&mut self, event: arbiter_token::TransferFilter) -> Option { + let messager = self.messager.as_ref().unwrap(); trace!( "Got event for `TokenRequester` logger: {:?}", @@ -276,16 +282,16 @@ self.messager.id.as_deref()))] ); std::thread::sleep(std::time::Duration::from_secs(1)); let message = Message { - from: self.messager.id.clone().unwrap(), + from: messager.id.clone().unwrap(), to: To::Agent(self.request_to.clone()), data: serde_json::to_string(&TokenAdminQuery::MintRequest(MintRequest { token: self.token_data.name.clone(), - mint_to: self.client.address(), + mint_to: self.client.as_ref().unwrap().address(), mint_amount: 1, })) .unwrap(), }; - self.messager.send(message).await; + messager.send(message).await; self.count += 1; if self.count == self.max_count.unwrap_or(u64::MAX) { warn!("Reached max count. Halting behavior."); @@ -298,23 +304,11 @@ self.messager.id.as_deref()))] #[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn token_minter_simulation() { - // std::env::set_var("RUST_LOG", "trace"); - // tracing_subscriber::fmt::init(); - let mut world = World::new("test_world"); // Create the token admin agent let token_admin = Agent::new(TOKEN_ADMIN_ID, &world); - let mut token_admin_behavior = TokenAdmin::new( - token_admin.client.clone(), - token_admin - .messager - .as_ref() - .unwrap() - .join_with_id(Some(TOKEN_ADMIN_ID.to_owned())), - 0, - Some(4), - ); + let mut token_admin_behavior = TokenAdmin::new(0, Some(4)); token_admin_behavior.add_token(TokenData { name: TOKEN_NAME.to_owned(), symbol: TOKEN_SYMBOL.to_owned(), diff --git a/arbiter-engine/src/machine.rs b/arbiter-engine/src/machine.rs index c16cd623..856f42c9 100644 --- a/arbiter-engine/src/machine.rs +++ b/arbiter-engine/src/machine.rs @@ -1,18 +1,30 @@ //! The [`StateMachine`] trait, [`Behavior`] trait, and the [`Engine`] that runs //! [`Behavior`]s. -use std::fmt::Debug; +// TODO: Notes +// I think we should have the `sync` stage of the behavior receive the client +// and messager and then the user can decide if it wants to use those in their +// behavior. +// Could typestate pattern help here at all? Sync could produce a `Synced` state +// behavior that can then not have options for client and messager. Then the +// user can decide if they want to use those in their behavior and get a bit +// simpler UX. + +use std::{fmt::Debug, sync::Arc}; + +use arbiter_core::middleware::RevmMiddleware; use serde::de::DeserializeOwned; use tokio::sync::broadcast::Receiver; +use self::messager::Messager; use super::*; /// The instructions that can be sent to a [`StateMachine`]. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] pub enum MachineInstruction { /// Used to make a [`StateMachine`] sync with the world. - Sync, + Sync(Option, Option>), /// Used to make a [`StateMachine`] start up. Start, @@ -71,7 +83,7 @@ pub enum State { pub trait Behavior: Send + Sync + 'static { /// Used to bring the agent back up to date with the latest state of the /// world. This could be used if the world was stopped and later restarted. - async fn sync(&mut self) {} + async fn sync(&mut self, _messager: Messager, _client: Arc) {} /// Used to start the agent. /// This is where the agent can engage in its specific start up activities @@ -139,12 +151,12 @@ where { async fn execute(&mut self, instruction: MachineInstruction) { match instruction { - MachineInstruction::Sync => { + MachineInstruction::Sync(messager, client) => { trace!("Behavior is syncing."); self.state = State::Syncing; let mut behavior = self.behavior.take().unwrap(); let behavior_task = tokio::spawn(async move { - behavior.sync().await; + behavior.sync(messager.unwrap(), client.unwrap()).await; behavior }); self.behavior = Some(behavior_task.await.unwrap()); diff --git a/arbiter-engine/src/messager.rs b/arbiter-engine/src/messager.rs index 290fd6f3..041b7ccd 100644 --- a/arbiter-engine/src/messager.rs +++ b/arbiter-engine/src/messager.rs @@ -40,6 +40,16 @@ pub struct Messager { broadcast_receiver: Option>, } +impl Clone for Messager { + fn clone(&self) -> Self { + Self { + broadcast_sender: self.broadcast_sender.clone(), + broadcast_receiver: Some(self.broadcast_sender.subscribe()), + id: self.id.clone(), + } + } +} + impl Messager { // TODO: Allow for modulating the capacity of the messager. // TODO: It might be nice to have some kind of messaging header so that we can diff --git a/arbiter-engine/src/world.rs b/arbiter-engine/src/world.rs index d4d33c2f..e435e165 100644 --- a/arbiter-engine/src/world.rs +++ b/arbiter-engine/src/world.rs @@ -107,7 +107,7 @@ impl World { /// Runs the world through up to the [`State::Processing`] stage. pub async fn run(&mut self) { - self.execute(MachineInstruction::Sync).await; + self.execute(MachineInstruction::Sync(None, None)).await; self.execute(MachineInstruction::Start).await; self.execute(MachineInstruction::Process).await; } @@ -137,13 +137,14 @@ impl World { impl StateMachine for World { async fn execute(&mut self, instruction: MachineInstruction) { match instruction { - MachineInstruction::Sync => { + MachineInstruction::Sync(_, _) => { info!("World is syncing."); self.state = State::Syncing; let agents = self.agents.take().unwrap(); let agent_tasks = join_all(agents.into_values().map(|mut agent| { + let instruction_clone = instruction.clone(); tokio::spawn(async move { - agent.execute(instruction).await; + agent.execute(instruction_clone).await; agent }) })); @@ -163,8 +164,9 @@ impl StateMachine for World { self.state = State::Starting; let agents = self.agents.take().unwrap(); let agent_tasks = join_all(agents.into_values().map(|mut agent| { + let instruction_clone = instruction.clone(); tokio::spawn(async move { - agent.execute(instruction).await; + agent.execute(instruction_clone).await; agent }) })); @@ -186,8 +188,9 @@ impl StateMachine for World { let mut agent_distributors = vec![]; let agent_processors = join_all(agents.into_values().map(|mut agent| { agent_distributors.push(agent.distributor.0.clone()); + let instruction_clone = instruction.clone(); tokio::spawn(async move { - agent.execute(instruction).await; + agent.execute(instruction_clone).await; agent }) }));