diff --git a/Cargo.toml b/Cargo.toml index ea1daba..883f820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,10 +4,11 @@ version = "0.0.1" edition = "2021" [features] -default = ["http"] +default = ["http", "anyhow"] http = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"] [dependencies] +anyhow = {version = "1.0", optional = true} bytes = "1.6.1" futures = "0.3" http-body-util = { version = "0.1", optional = true } diff --git a/src/errors.rs b/src/errors.rs index 73017dd..0bf4f96 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -32,6 +32,13 @@ impl StdError for HandlerErrorInner { #[derive(Debug)] pub struct HandlerError(pub(crate) HandlerErrorInner); +impl HandlerError { + #[cfg(feature = "anyhow")] + pub fn from_anyhow(err: anyhow::Error) -> Self { + Self(HandlerErrorInner::Retryable(err.into())) + } +} + impl From for HandlerError { fn from(value: E) -> Self { Self(HandlerErrorInner::Retryable(Box::new(value))) @@ -44,6 +51,7 @@ impl From for HandlerError { } } +// Took from anyhow impl AsRef for HandlerError { fn as_ref(&self) -> &(dyn StdError + Send + Sync + 'static) { &self.0 diff --git a/src/serde.rs b/src/serde.rs index e5f521b..bae9065 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -139,7 +139,7 @@ impl_serde_primitives!(f64); // --- Json responses -pub struct Json(T); +pub struct Json(pub T); impl Json { pub fn into_inner(self) -> T { diff --git a/test-services/Cargo.toml b/test-services/Cargo.toml index d558786..0b2a469 100644 --- a/test-services/Cargo.toml +++ b/test-services/Cargo.toml @@ -5,8 +5,10 @@ edition = "2021" publish = false [dependencies] +anyhow = "1.0" tokio = { version = "1", features = ["full"] } tracing-subscriber = "0.3" +futures = "0.3" restate-sdk = { path = ".." } serde = { version = "1", features = ["derive"] } tracing = "0.1.40" \ No newline at end of file diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index efad6cd..c9b30bd 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,49 +1,39 @@ exclusions: "alwaysSuspending": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.SideEffect" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.State" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" + - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" + - "dev.restate.sdktesting.tests.SideEffect" + - "dev.restate.sdktesting.tests.Sleep" + - "dev.restate.sdktesting.tests.SleepWithFailures" + - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" + - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" + - "dev.restate.sdktesting.tests.UserErrors" + - "dev.restate.sdktesting.tests.WorkflowAPI" "default": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.CallOrdering" - - "dev.restate.sdktesting.tests.CancelInvocation" - - "dev.restate.sdktesting.tests.Ingress" - - "dev.restate.sdktesting.tests.KafkaIngress" - - "dev.restate.sdktesting.tests.KillInvocation" - - "dev.restate.sdktesting.tests.PrivateService" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.State" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" - "lazyState": - - "dev.restate.sdktesting.tests.State" + - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.CallOrdering" + - "dev.restate.sdktesting.tests.CancelInvocation" + - "dev.restate.sdktesting.tests.Ingress" + - "dev.restate.sdktesting.tests.KillInvocation" + - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" + - "dev.restate.sdktesting.tests.Sleep" + - "dev.restate.sdktesting.tests.SleepWithFailures" + - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" + - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" + - "dev.restate.sdktesting.tests.UserErrors" + - "dev.restate.sdktesting.tests.WorkflowAPI" "persistedTimers": - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.Sleep" + - "dev.restate.sdktesting.tests.Sleep" "singleThreadSinglePartition": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.CallOrdering" - - "dev.restate.sdktesting.tests.CancelInvocation" - - "dev.restate.sdktesting.tests.Ingress" - - "dev.restate.sdktesting.tests.KafkaIngress" - - "dev.restate.sdktesting.tests.KillInvocation" - - "dev.restate.sdktesting.tests.PrivateService" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.State" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" + - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.CallOrdering" + - "dev.restate.sdktesting.tests.CancelInvocation" + - "dev.restate.sdktesting.tests.Ingress" + - "dev.restate.sdktesting.tests.KillInvocation" + - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" + - "dev.restate.sdktesting.tests.Sleep" + - "dev.restate.sdktesting.tests.SleepWithFailures" + - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" + - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" + - "dev.restate.sdktesting.tests.UserErrors" + - "dev.restate.sdktesting.tests.WorkflowAPI" diff --git a/test-services/src/main.rs b/test-services/src/main.rs index 3c69d9e..96ff8cb 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -1,4 +1,6 @@ mod counter; +mod map_object; +mod proxy; use restate_sdk::prelude::{Endpoint, HyperServer}; use std::env; @@ -14,6 +16,12 @@ async fn main() { if services == "*" || services.contains("Counter") { builder = builder.with_service(counter::Counter::serve(counter::CounterImpl)) } + if services == "*" || services.contains("Proxy") { + builder = builder.with_service(proxy::Proxy::serve(proxy::ProxyImpl)) + } + if services == "*" || services.contains("MapObject") { + builder = builder.with_service(map_object::MapObject::serve(map_object::MapObjectImpl)) + } HyperServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) diff --git a/test-services/src/map_object.rs b/test-services/src/map_object.rs new file mode 100644 index 0000000..4bc4320 --- /dev/null +++ b/test-services/src/map_object.rs @@ -0,0 +1,55 @@ +use anyhow::anyhow; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct Entry { + key: String, + value: String, +} + +#[restate_sdk::object] +#[name = "MapObject"] +pub(crate) trait MapObject { + #[name = "set"] + async fn set(entry: Json) -> HandlerResult<()>; + #[name = "get"] + async fn get(key: String) -> HandlerResult; + #[name = "clearAll"] + async fn clear_all() -> HandlerResult>>; +} + +pub(crate) struct MapObjectImpl; + +impl MapObject for MapObjectImpl { + async fn set( + &self, + ctx: ObjectContext<'_>, + Json(Entry { key, value }): Json, + ) -> HandlerResult<()> { + ctx.set(&key, value); + Ok(()) + } + + async fn get(&self, ctx: ObjectContext<'_>, key: String) -> HandlerResult { + Ok(ctx.get(&key).await?.unwrap_or_default()) + } + + async fn clear_all(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + let keys = ctx.get_keys().await?; + + let mut entries = vec![]; + for k in keys { + let value = ctx + .get(&k) + .await? + .ok_or_else(|| HandlerError::from_anyhow(anyhow!("Missing key {k}")))?; + entries.push(Entry { key: k, value }) + } + + ctx.clear_all(); + + Ok(entries.into()) + } +} diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs new file mode 100644 index 0000000..2fbe234 --- /dev/null +++ b/test-services/src/proxy.rs @@ -0,0 +1,107 @@ +use futures::future::BoxFuture; +use futures::FutureExt; +use restate_sdk::context::RequestTarget; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ProxyRequest { + service_name: String, + virtual_object_key: Option, + handler_name: String, + message: Vec, + delay_millis: Option, +} + +impl ProxyRequest { + fn to_target(&self) -> RequestTarget { + if let Some(key) = &self.virtual_object_key { + RequestTarget::Object { + name: self.service_name.clone(), + key: key.clone(), + handler: self.handler_name.clone(), + } + } else { + RequestTarget::Service { + name: self.service_name.clone(), + handler: self.handler_name.clone(), + } + } + } +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ManyCallRequest { + proxy_request: ProxyRequest, + one_way_call: bool, + await_at_the_end: bool, +} + +#[restate_sdk::service] +#[name = "Proxy"] +pub(crate) trait Proxy { + #[name = "call"] + async fn call(req: Json) -> HandlerResult>>; + #[name = "oneWayCall"] + async fn one_way_call(req: Json) -> HandlerResult<()>; + #[name = "manyCalls"] + async fn many_calls(req: Json>) -> HandlerResult<()>; +} + +pub(crate) struct ProxyImpl; + +impl Proxy for ProxyImpl { + async fn call( + &self, + ctx: Context<'_>, + Json(req): Json, + ) -> HandlerResult>> { + Ok(ctx.call(req.to_target(), req.message).await?) + } + + async fn one_way_call( + &self, + ctx: Context<'_>, + Json(req): Json, + ) -> HandlerResult<()> { + ctx.send( + req.to_target(), + req.message, + req.delay_millis.map(Duration::from_millis), + ); + Ok(()) + } + + async fn many_calls( + &self, + ctx: Context<'_>, + Json(requests): Json>, + ) -> HandlerResult<()> { + let mut futures: Vec, TerminalError>>> = vec![]; + + for req in requests { + if req.one_way_call { + ctx.send( + req.proxy_request.to_target(), + req.proxy_request.message, + req.proxy_request.delay_millis.map(Duration::from_millis), + ); + } else { + let fut = ctx + .call::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); + if req.await_at_the_end { + futures.push(fut.boxed()) + } + } + } + + for fut in futures { + fut.await?; + } + + Ok(()) + } +}