diff --git a/server/src/frontend_api.rs b/server/src/frontend_api.rs index 6c9abb1e..dcd8c2e4 100644 --- a/server/src/frontend_api.rs +++ b/server/src/frontend_api.rs @@ -21,7 +21,7 @@ use unleash_types::{ use unleash_yggdrasil::{EngineState, ResolvedToggle}; use crate::error::EdgeError::ContextParseError; -use crate::types::{ClientIp, IncomingContext}; +use crate::types::{ClientIp, IncomingContext, PostContext}; use crate::{ error::{EdgeError, FrontendHydrationMissing}, metrics::client_metrics::MetricsCache, @@ -103,7 +103,7 @@ async fn post_proxy_all_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + context: Json, req: HttpRequest, ) -> EdgeJsonResult { post_all_features( @@ -184,7 +184,7 @@ async fn post_frontend_all_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + context: Json, req: HttpRequest, ) -> EdgeJsonResult { post_all_features( @@ -200,10 +200,10 @@ fn post_all_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + incoming_context: Json, client_ip: Option<&ClientIp>, ) -> EdgeJsonResult { - let context: Context = context.into_inner().into(); + let context: Context = incoming_context.into_inner().into(); let context_with_ip = if context.remote_address.is_none() { Context { remote_address: client_ip.map(|ip| ip.to_string()), @@ -343,7 +343,7 @@ async fn post_proxy_enabled_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + context: Json, req: HttpRequest, ) -> EdgeJsonResult { let client_ip = req.extensions().get::().cloned(); @@ -367,7 +367,7 @@ async fn post_frontend_enabled_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + context: Json, req: HttpRequest, ) -> EdgeJsonResult { let client_ip = req.extensions().get::().cloned(); @@ -392,7 +392,7 @@ security( pub async fn post_frontend_evaluate_single_feature( edge_token: EdgeToken, feature_name: Path, - context: Json, + context: Json, engine_cache: Data>, token_cache: Data>, req: HttpRequest, @@ -400,7 +400,7 @@ pub async fn post_frontend_evaluate_single_feature( evaluate_feature( edge_token, feature_name.into_inner(), - &context.into_inner(), + &context.into_inner().into(), token_cache, engine_cache, req.extensions().get::().cloned(), @@ -436,7 +436,7 @@ pub async fn get_frontend_evaluate_single_feature( evaluate_feature( edge_token, feature_name.into_inner(), - &context.into_inner(), + &context.into_inner().into(), token_cache, engine_cache, req.extensions().get::().cloned(), @@ -447,12 +447,12 @@ pub async fn get_frontend_evaluate_single_feature( pub fn evaluate_feature( edge_token: EdgeToken, feature_name: String, - incoming_context: &IncomingContext, + incoming_context: &Context, token_cache: Data>, engine_cache: Data>, client_ip: Option, ) -> EdgeResult { - let context: Context = incoming_context.clone().into(); + let context: Context = incoming_context.clone(); let context_with_ip = if context.remote_address.is_none() { Context { remote_address: client_ip.map(|ip| ip.to_string()), @@ -496,7 +496,7 @@ async fn post_enabled_features( edge_token: EdgeToken, engine_cache: Data>, token_cache: Data>, - context: Json, + context: Json, client_ip: Option, ) -> EdgeJsonResult { let context: Context = context.into_inner().into(); @@ -1547,7 +1547,8 @@ mod tests { .insert_header(("Authorization", auth_key.clone())) .set_json(json!({ "properties": {"companyId": "bricks"}})) .to_request(); - let result: FrontendResult = test::call_and_read_body_json(&app, req).await; + let result: FrontendResult = test::try_call_and_read_body_json(&app, req).await.expect("Failed to call endpoint"); + tracing::info!("{result:?}"); assert_eq!(result.toggles.len(), 1); } diff --git a/server/src/types.rs b/server/src/types.rs index 3160e9cf..fda4521b 100644 --- a/server/src/types.rs +++ b/server/src/types.rs @@ -53,6 +53,30 @@ impl From for Context { } } +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct PostContext { + pub context: Option, + #[serde(flatten)] + pub flattened_context: Option, + #[serde(flatten)] + pub extra_properties: HashMap, +} + +impl From for Context { + fn from(input: PostContext) -> Self { + if let Some(context) = input.context { + context + } else { + IncomingContext { + context: input.flattened_context.unwrap_or_default(), + extra_properties: input.extra_properties, + } + .into() + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)] #[serde(rename_all = "lowercase")] pub enum TokenType { @@ -494,6 +518,7 @@ mod tests { use std::collections::HashMap; use std::str::FromStr; + use serde_json::json; use test_case::test_case; use tracing::warn; use unleash_types::client_features::Context; @@ -502,6 +527,8 @@ mod tests { use crate::http::unleash_client::EdgeTokens; use crate::types::{EdgeResult, EdgeToken, IncomingContext}; + use super::PostContext; + fn test_str(token: &str) -> EdgeToken { EdgeToken::from_str( &(token.to_owned() + ".614a75cf68bef8703aa1bd8304938a81ec871f86ea40c975468eabd6"), @@ -687,4 +714,97 @@ mod tests { let converted: Context = incoming_context.into(); assert_eq!(converted.properties, None); } + + #[test] + fn completely_flat_json_parses_to_a_context() { + let json = json!( + { + "userId": "7", + "flat": "endsUpInProps", + "invalidProperty": "alsoEndsUpInProps" + } + ); + + let post_context: PostContext = serde_json::from_value(json).unwrap(); + let parsed_context: Context = post_context.into(); + + assert_eq!(parsed_context.user_id, Some("7".into())); + assert_eq!( + parsed_context.properties, + Some(HashMap::from([ + ("flat".into(), "endsUpInProps".into()), + ("invalidProperty".into(), "alsoEndsUpInProps".into()) + ])) + ); + } + + #[test] + fn post_context_root_level_properties_are_ignored_if_context_property_is_set() { + let json = json!( + { + "context": { + "userId":"7", + }, + "invalidProperty": "thisNeverGoesAnywhere", + "anotherInvalidProperty": "alsoGoesNoWhere" + } + ); + + let post_context: PostContext = serde_json::from_value(json).unwrap(); + let parsed_context: Context = post_context.into(); + assert_eq!(parsed_context.properties, None); + + assert_eq!(parsed_context.user_id, Some("7".into())); + } + + #[test] + fn post_context_properties_are_taken_from_nested_context_object_but_root_levels_are_ignored() { + let json = json!( + { + "context": { + "userId":"7", + "properties": { + "nested": "nestedValue" + } + }, + "invalidProperty": "thisNeverGoesAnywhere" + } + ); + + let post_context: PostContext = serde_json::from_value(json).unwrap(); + let parsed_context: Context = post_context.into(); + assert_eq!( + parsed_context.properties, + Some(HashMap::from([("nested".into(), "nestedValue".into()),])) + ); + + assert_eq!(parsed_context.user_id, Some("7".into())); + } + + #[test] + fn post_context_properties_are_taken_from_nested_context_object_but_custom_properties_on_context_are_ignored( + ) { + let json = json!( + { + "context": { + "userId":"7", + "howDidYouGetHere": "I dunno bro", + "properties": { + "nested": "nestedValue" + } + }, + "flat": "endsUpInProps", + "invalidProperty": "thisNeverGoesAnywhere" + } + ); + + let post_context: PostContext = serde_json::from_value(json).unwrap(); + let parsed_context: Context = post_context.into(); + assert_eq!( + parsed_context.properties, + Some(HashMap::from([("nested".into(), "nestedValue".into()),])) + ); + + assert_eq!(parsed_context.user_id, Some("7".into())); + } }