From 76b2514a30077b6e64949c77dc7ced368613de06 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Fri, 13 Dec 2024 15:47:16 +0100 Subject: [PATCH 01/16] First stab at adding VPC endpoint checks --- proxy/src/auth/backend/mod.rs | 55 +++++++++++++++++-- proxy/src/auth/mod.rs | 12 ++++ proxy/src/cache/project_info.rs | 14 +++++ proxy/src/context/mod.rs | 11 +++- .../control_plane/client/cplane_proxy_v1.rs | 26 +++++++-- proxy/src/control_plane/client/mock.rs | 9 ++- proxy/src/control_plane/client/mod.rs | 6 +- proxy/src/control_plane/messages.rs | 19 ++++++- proxy/src/control_plane/mod.rs | 6 +- proxy/src/proxy/tests/mod.rs | 4 +- proxy/src/serverless/backend.rs | 22 +++++++- 11 files changed, 160 insertions(+), 24 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index d17d91a56d96..3768e003ae7c 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -26,10 +26,12 @@ use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, + CachedRoleSecret, ControlPlaneApi, }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; +use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; @@ -283,7 +285,8 @@ async fn auth_quirks( }; debug!("fetching user's authentication info"); - let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?; + let (allowed_ips, allowed_vpc_endpoint_ids, maybe_secret) = + api.get_allowed_ips_and_secret(ctx, &info).await?; // check allowed list if config.ip_allowlist_check_enabled @@ -292,6 +295,24 @@ async fn auth_quirks( return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); } + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + // TODO: Add flag to enable/disable VPC endpoint ID check + let extra = ctx.extra(); + let incoming_endpoint_id = match extra { + None => "".to_string(), + Some(ConnectionInfoExtra::Aws { vpce_id }) => { + // Convert the vcpe_id to a string + match String::from_utf8(vpce_id.to_vec()) { + Ok(s) => s, + Err(_e) => "".to_string(), + } + } + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + if incoming_endpoint_id != "" && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + } + if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { return Err(AuthError::too_many_connections()); } @@ -443,12 +464,23 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_allowed_ips_and_secret( &self, ctx: &RequestContext, - ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { + ) -> Result< + ( + CachedAllowedIps, + CachedAllowedVpcEndpointIds, + Option, + ), + GetAuthInfoError, + > { match self { Self::ControlPlane(api, user_info) => { api.get_allowed_ips_and_secret(ctx, user_info).await } - Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::Local(_) => Ok(( + Cached::new_uncached(Arc::new(vec![])), + Cached::new_uncached(Arc::new(vec![])), + None, + )), } } } @@ -514,7 +546,9 @@ mod tests { use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; - use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret}; + use crate::control_plane::{ + self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + }; use crate::proxy::NeonOptions; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; @@ -523,6 +557,7 @@ mod tests { struct Auth { ips: Vec, + vpc_endpoint_ids: Vec, secret: AuthSecret, } @@ -540,11 +575,16 @@ mod tests { _ctx: &RequestContext, _user_info: &super::ComputeUserInfo, ) -> Result< - (CachedAllowedIps, Option), + ( + CachedAllowedIps, + CachedAllowedVpcEndpointIds, + Option, + ), control_plane::errors::GetAuthInfoError, > { Ok(( CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())), + CachedAllowedVpcEndpointIds::new_uncached(Arc::new(self.vpc_endpoint_ids.clone())), Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))), )) } @@ -642,6 +682,7 @@ mod tests { let ctx = RequestContext::test(); let api = Auth { ips: vec![], + vpc_endpoint_ids: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; @@ -722,6 +763,7 @@ mod tests { let ctx = RequestContext::test(); let api = Auth { ips: vec![], + vpc_endpoint_ids: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; @@ -774,6 +816,7 @@ mod tests { let ctx = RequestContext::test(); let api = Auth { ips: vec![], + vpc_endpoint_ids: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; diff --git a/proxy/src/auth/mod.rs b/proxy/src/auth/mod.rs index 0198cc306e08..987e5de25d56 100644 --- a/proxy/src/auth/mod.rs +++ b/proxy/src/auth/mod.rs @@ -69,6 +69,12 @@ pub(crate) enum AuthError { )] IpAddressNotAllowed(IpAddr), + #[error( + "This VPC endpoint id {0} is not allowed to connect to this endpoint. \ + Please add it to the allowed list in the Neon console." + )] + VpcEndpointIdNotAllowed(String), + #[error("Too many connections to this endpoint. Please try again later.")] TooManyConnections, @@ -95,6 +101,10 @@ impl AuthError { AuthError::IpAddressNotAllowed(ip) } + pub(crate) fn vpc_endpoint_id_not_allowed(id: String) -> Self { + AuthError::VpcEndpointIdNotAllowed(id) + } + pub(crate) fn too_many_connections() -> Self { AuthError::TooManyConnections } @@ -124,6 +134,7 @@ impl UserFacingError for AuthError { Self::MissingEndpointName => self.to_string(), Self::Io(_) => "Internal error".to_string(), Self::IpAddressNotAllowed(_) => self.to_string(), + Self::VpcEndpointIdNotAllowed(_) => self.to_string(), Self::TooManyConnections => self.to_string(), Self::UserTimeout(_) => self.to_string(), Self::ConfirmationTimeout(_) => self.to_string(), @@ -144,6 +155,7 @@ impl ReportableError for AuthError { Self::MissingEndpointName => crate::error::ErrorKind::User, Self::Io(_) => crate::error::ErrorKind::ClientDisconnect, Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User, + Self::VpcEndpointIdNotAllowed(_) => crate::error::ErrorKind::User, Self::TooManyConnections => crate::error::ErrorKind::RateLimit, Self::UserTimeout(_) => crate::error::ErrorKind::User, Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index a5e71f1a8744..e2d49630ca6a 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -51,6 +51,7 @@ impl From for Entry { struct EndpointInfo { secret: std::collections::HashMap>>, allowed_ips: Option>>>, + allowed_vpc_endpoint_ids: Option>>>, } impl EndpointInfo { @@ -256,6 +257,19 @@ impl ProjectInfoCacheImpl { self.insert_project2endpoint(project_id, endpoint_id); self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); } + pub(crate) fn insert_allowed_vpc_endpoint_ids( + &self, + project_id: ProjectIdInt, + endpoint_id: EndpointIdInt, + allowed_vpc_endpoint_ids: Arc>, + ) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. + return; + } + self.insert_project2endpoint(project_id, endpoint_id); + self.cache.entry(endpoint_id).or_default().allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); + } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) { endpoints.insert(endpoint_id); diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index a9fb513d3ceb..3236b2e1bfb0 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -19,7 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt}; use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting, }; -use crate::protocol2::ConnectionInfo; +use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; pub mod parquet; @@ -312,6 +312,15 @@ impl RequestContext { .ip() } + pub(crate) fn extra(&self) -> Option { + self.0 + .try_lock() + .expect("should not deadlock") + .conn_info + .extra + .clone() + } + pub(crate) fn cold_start_info(&self) -> ColdStartInfo { self.0 .try_lock() diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index ece03156d1fa..39a4100b97dd 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -22,7 +22,7 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AuthInfo, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, }; use crate::metrics::{CacheOutcome, Metrics}; use crate::rate_limiter::WakeComputeRateLimiter; @@ -137,9 +137,6 @@ impl NeonControlPlaneClient { } }; - // Ivan: don't know where it will be used, so I leave it here - let _endpoint_vpc_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default(); - let secret = if body.role_secret.is_empty() { None } else { @@ -153,9 +150,12 @@ impl NeonControlPlaneClient { .proxy .allowed_ips_number .observe(allowed_ips.len() as f64); + let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default(); + // TODO: Add metrics? Ok(AuthInfo { secret, allowed_ips, + allowed_vpc_endpoint_ids, project_id: body.project_id, }) } @@ -312,6 +312,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { normalized_ep_int, Arc::new(auth_info.allowed_ips), ); + self.caches.project_info.insert_allowed_vpc_endpoint_ids( + project_id, + normalized_ep_int, + Arc::new(auth_info.allowed_vpc_endpoint_ids), + ); ctx.set_project_id(project_id); } // When we just got a secret, we don't need to invalidate it. @@ -322,14 +327,16 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), GetAuthInfoError> { let normalized_ep = &user_info.endpoint.normalize(); + let allowed_vcp_endpoint_ids = Arc::new(Vec::new()); if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { Metrics::get() .proxy .allowed_ips_cache_misses .inc(CacheOutcome::Hit); - return Ok((allowed_ips, None)); + // TODO + return Ok((allowed_ips, Cached::new_uncached(allowed_vcp_endpoint_ids), None)); } Metrics::get() .proxy @@ -337,6 +344,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { .inc(CacheOutcome::Miss); let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); + let allowed_vcp_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); let user = &user_info.user; if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); @@ -351,10 +359,16 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { normalized_ep_int, allowed_ips.clone(), ); + self.caches.project_info.insert_allowed_vpc_endpoint_ids( + project_id, + normalized_ep_int, + allowed_vcp_endpoint_ids.clone(), + ); ctx.set_project_id(project_id); } Ok(( Cached::new_uncached(allowed_ips), + Cached::new_uncached(allowed_vcp_endpoint_ids), Some(Cached::new_uncached(auth_info.secret)), )) } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 5f8bda0f35ae..5a1ba972f093 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -13,7 +13,7 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::IpPattern; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret}; +use crate::control_plane::client::{CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; @@ -121,6 +121,8 @@ impl MockControlPlane { Ok(AuthInfo { secret, allowed_ips, + // TODO + allowed_vpc_endpoint_ids: vec![], project_id: None, }) } @@ -218,11 +220,14 @@ impl super::ControlPlaneApi for MockControlPlane { &self, _ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), GetAuthInfoError> { Ok(( Cached::new_uncached(Arc::new( self.do_get_auth_info(user_info).await?.allowed_ips, )), + Cached::new_uncached(Arc::new( + self.do_get_auth_info(user_info).await?.allowed_vpc_endpoint_ids, + )), None, )) } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index b879f3a59ff4..aec8f06345fd 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -17,7 +17,7 @@ use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; use crate::control_plane::{ - errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache, + errors, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache, }; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; @@ -59,7 +59,7 @@ impl ControlPlaneApi for ControlPlaneClient { &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError> { + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError> { match self { Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] @@ -104,7 +104,7 @@ pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn get_allowed_ips_and_secret( &self, - ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError>; + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError>; fn dyn_clone(&self) -> Box; } diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index d068614b24df..0ae3cd70fca2 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -227,8 +227,8 @@ pub(crate) struct UserFacingMessage { pub(crate) struct GetEndpointAccessControl { pub(crate) role_secret: Box, pub(crate) allowed_ips: Option>, + pub(crate) allowed_vpc_endpoint_ids: Option>, pub(crate) project_id: Option, - pub(crate) allowed_vpc_endpoint_ids: Option>, } /// Response which holds compute node's `host:port` pair. @@ -282,6 +282,8 @@ pub(crate) struct DatabaseInfo { pub(crate) aux: MetricsAuxInfo, #[serde(default)] pub(crate) allowed_ips: Option>, + #[serde(default)] + pub(crate) allowed_vpc_endpoint_ids: Option>, } // Manually implement debug to omit sensitive info. @@ -293,6 +295,7 @@ impl fmt::Debug for DatabaseInfo { .field("dbname", &self.dbname) .field("user", &self.user) .field("allowed_ips", &self.allowed_ips) + .field("allowed_vpc_endpoint_ids", &self.allowed_vpc_endpoint_ids) .finish_non_exhaustive() } } @@ -457,7 +460,7 @@ mod tests { #[test] fn parse_get_role_secret() -> anyhow::Result<()> { - // Empty `allowed_ips` field. + // Empty `allowed_ips` and `allowed_vcp_endpoint_ids` field. let json = json!({ "role_secret": "secret", }); @@ -467,9 +470,21 @@ mod tests { "allowed_ips": ["8.8.8.8"], }); serde_json::from_str::(&json.to_string())?; + let json = json!({ + "role_secret": "secret", + "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"], + }); + serde_json::from_str::(&json.to_string())?; + let json = json!({ + "role_secret": "secret", + "allowed_ips": ["8.8.8.8"], + "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"], + }); + serde_json::from_str::(&json.to_string())?; let json = json!({ "role_secret": "secret", "allowed_ips": ["8.8.8.8"], + "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"], "project_id": "project", }); serde_json::from_str::(&json.to_string())?; diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 1dca26d6866c..60db0ac2bd6f 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -52,6 +52,8 @@ pub(crate) struct AuthInfo { pub(crate) secret: Option, /// List of IP addresses allowed for the autorization. pub(crate) allowed_ips: Vec, + /// List of VPC endpoints allowed for the autorization. + pub(crate) allowed_vpc_endpoint_ids: Vec, /// Project ID. This is used for cache invalidation. pub(crate) project_id: Option, } @@ -100,6 +102,7 @@ pub(crate) type NodeInfoCache = pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub(crate) type CachedAllowedVpcEndpointIds = Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -113,11 +116,12 @@ pub(crate) trait ControlPlaneApi { user_info: &ComputeUserInfo, ) -> Result; + // TODO: Should we rename this one? It does more... async fn get_allowed_ips_and_secret( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError>; + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError>; async fn get_endpoint_jwks( &self, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 10db2bcb303f..769040cf0e82 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,7 +26,7 @@ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{ - self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, + self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; @@ -528,7 +528,7 @@ impl TestControlPlaneClient for TestConnectMechanism { fn get_allowed_ips_and_secret( &self, - ) -> Result<(CachedAllowedIps, Option), control_plane::errors::GetAuthInfoError> + ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), control_plane::errors::GetAuthInfoError> { unimplemented!("not used in tests") } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 6d5fb13681e9..051bdef4bcea 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -30,6 +30,7 @@ use crate::control_plane::locks::ApiLocks; use crate::control_plane::CachedNodeInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; +use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; @@ -57,7 +58,26 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; + let (allowed_ips, allowed_vpce_ids, maybe_secret) = + backend.get_allowed_ips_and_secret(ctx).await?; + + let extra = ctx.extra(); + let incoming_endpoint_id = match extra { + None => "".to_string(), + Some(ConnectionInfoExtra::Aws { vpce_id }) => { + // Convert the vcpe_id to a string + match String::from_utf8(vpce_id.to_vec()) { + Ok(s) => s, + Err(_e) => "".to_string(), + } + } + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + if incoming_endpoint_id != "" && !allowed_vpce_ids.contains(&incoming_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + } + if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { From 5d8d5d19c97d182f0542b03e73b936859576f9fb Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Fri, 13 Dec 2024 16:10:55 +0100 Subject: [PATCH 02/16] Added some more wiring of the cache --- proxy/src/cache/project_info.rs | 48 +++++++++++++++++++ .../control_plane/client/cplane_proxy_v1.rs | 15 +++--- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index e2d49630ca6a..19a7bd0a4abc 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -93,9 +93,27 @@ impl EndpointInfo { } None } + pub(crate) fn get_allowed_vpc_endpoint_ids( + &self, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<(Arc>, bool)> { + if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { + if valid_since < allowed_vpc_endpoint_ids.created_at { + return Some(( + allowed_vpc_endpoint_ids.value.clone(), + Self::check_ignore_cache(ignore_cache_since, allowed_vpc_endpoint_ids.created_at), + )); + } + } + None + } pub(crate) fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } + pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { + self.allowed_vpc_endpoint_ids = None; + } pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { self.secret.remove(&role_name); } @@ -227,6 +245,24 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } + pub(crate) fn get_allowed_vpc_endpoint_ids( + &self, + endpoint_id: &EndpointId, + ) -> Option>>> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + let (valid_since, ignore_cache_since) = self.get_cache_times(); + let endpoint_info = self.cache.get(&endpoint_id)?; + let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); + let (value, ignore_cache) = value?; + if !ignore_cache { + let cached = Cached { + token: Some((self, CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id))), + value, + }; + return Some(cached); + } + Some(Cached::new_uncached(value)) + } pub(crate) fn insert_role_secret( &self, project_id: ProjectIdInt, @@ -348,11 +384,18 @@ impl CachedLookupInfo { lookup_type: LookupType::AllowedIps, } } + pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self { + Self { + endpoint_id, + lookup_type: LookupType::AllowedVpcEndpointIds, + } + } } enum LookupType { RoleSecret(RoleNameInt), AllowedIps, + AllowedVpcEndpointIds } impl Cache for ProjectInfoCacheImpl { @@ -374,6 +417,11 @@ impl Cache for ProjectInfoCacheImpl { endpoint_info.invalidate_allowed_ips(); } } + LookupType::AllowedVpcEndpointIds => { + if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { + endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + } + } } } } diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 39a4100b97dd..17d3541b656d 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -329,14 +329,15 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), GetAuthInfoError> { let normalized_ep = &user_info.endpoint.normalize(); - let allowed_vcp_endpoint_ids = Arc::new(Vec::new()); if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Hit); - // TODO - return Ok((allowed_ips, Cached::new_uncached(allowed_vcp_endpoint_ids), None)); + if let Some(allowed_vcp_endpoint_ids) = self.caches.project_info.get_allowed_vpc_endpoint_ids(normalized_ep) { + Metrics::get() + .proxy + .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? + .inc(CacheOutcome::Hit); + // TODO SR: This I don't understand this. Why are we returning an empty secret here? + return Ok((allowed_ips, allowed_vcp_endpoint_ids, None)); + } } Metrics::get() .proxy From 60305aa78311f9d257ec10fe0717b7f1f5bd09c3 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Wed, 18 Dec 2024 09:47:24 +0100 Subject: [PATCH 03/16] Split get_allowed_ips_and_secret into 3 functions --- proxy/src/auth/backend/mod.rs | 73 ++++++++++--------- .../control_plane/client/cplane_proxy_v1.rs | 69 ++++++++++++++---- proxy/src/control_plane/client/mock.rs | 20 +++-- proxy/src/control_plane/client/mod.rs | 32 ++++++-- proxy/src/control_plane/messages.rs | 4 +- proxy/src/control_plane/mod.rs | 11 ++- proxy/src/proxy/tests/mod.rs | 13 +++- proxy/src/serverless/backend.rs | 20 +++-- 8 files changed, 159 insertions(+), 83 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 3768e003ae7c..7eb7688f8a29 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -284,9 +284,8 @@ async fn auth_quirks( Ok(info) => (info, None), }; - debug!("fetching user's authentication info"); - let (allowed_ips, allowed_vpc_endpoint_ids, maybe_secret) = - api.get_allowed_ips_and_secret(ctx, &info).await?; + debug!("fetching authentication info and allowlists"); + let allowed_ips = api.get_allowed_ips(ctx, &info).await?; // check allowed list if config.ip_allowlist_check_enabled @@ -309,17 +308,17 @@ async fn auth_quirks( } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; - if incoming_endpoint_id != "" && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + if incoming_endpoint_id != "" { + let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; + if !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + } } if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = match maybe_secret { - Some(secret) => secret, - None => api.get_role_secret(ctx, &info).await?, - }; + let cached_secret = api.get_role_secret(ctx, &info).await?; let (cached_entry, secret) = cached_secret.take_value(); let secret = if let Some(secret) = secret { @@ -461,26 +460,31 @@ impl Backend<'_, ComputeUserInfo> { } } - pub(crate) async fn get_allowed_ips_and_secret( + pub(crate) async fn get_allowed_ips( &self, ctx: &RequestContext, - ) -> Result< - ( - CachedAllowedIps, - CachedAllowedVpcEndpointIds, - Option, - ), - GetAuthInfoError, - > { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_ips_and_secret(ctx, user_info).await + api.get_allowed_ips(ctx, user_info).await } - Self::Local(_) => Ok(( + Self::Local(_) => Ok( Cached::new_uncached(Arc::new(vec![])), + ), + } + } + + pub(crate) async fn get_allowed_vpc_endpoint_ids( + &self, + ctx: &RequestContext, + ) -> Result { + match self { + Self::ControlPlane(api, user_info) => { + api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + } + Self::Local(_) => Ok( Cached::new_uncached(Arc::new(vec![])), - None, - )), + ), } } } @@ -570,23 +574,20 @@ mod tests { Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) } - async fn get_allowed_ips_and_secret( + async fn get_allowed_ips( &self, _ctx: &RequestContext, _user_info: &super::ComputeUserInfo, - ) -> Result< - ( - CachedAllowedIps, - CachedAllowedVpcEndpointIds, - Option, - ), - control_plane::errors::GetAuthInfoError, - > { - Ok(( - CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())), - CachedAllowedVpcEndpointIds::new_uncached(Arc::new(self.vpc_endpoint_ids.clone())), - Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))), - )) + ) -> Result { + Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) + } + + async fn get_allowed_vpc_endpoint_ids( + &self, + _ctx: &RequestContext, + _user_info: &super::ComputeUserInfo, + ) -> Result { + Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(self.vpc_endpoint_ids.clone()))) } async fn get_endpoint_jwks( diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 17d3541b656d..fad2ec21fd76 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -323,21 +323,18 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { Ok(Cached::new_uncached(auth_info.secret)) } - async fn get_allowed_ips_and_secret( + async fn get_allowed_ips( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), GetAuthInfoError> { + ) -> Result { let normalized_ep = &user_info.endpoint.normalize(); if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - if let Some(allowed_vcp_endpoint_ids) = self.caches.project_info.get_allowed_vpc_endpoint_ids(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? - .inc(CacheOutcome::Hit); - // TODO SR: This I don't understand this. Why are we returning an empty secret here? - return Ok((allowed_ips, allowed_vcp_endpoint_ids, None)); - } + Metrics::get() + .proxy + .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? + .inc(CacheOutcome::Hit); + return Ok(allowed_ips); } Metrics::get() .proxy @@ -367,11 +364,53 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ); ctx.set_project_id(project_id); } - Ok(( - Cached::new_uncached(allowed_ips), - Cached::new_uncached(allowed_vcp_endpoint_ids), - Some(Cached::new_uncached(auth_info.secret)), - )) + Ok(Cached::new_uncached(allowed_ips)) + } + + async fn get_allowed_vpc_endpoint_ids( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + let normalized_ep = &user_info.endpoint.normalize(); + if let Some(allowed_vcp_endpoint_ids) = self.caches.project_info.get_allowed_vpc_endpoint_ids(normalized_ep) { + Metrics::get() + .proxy + .allowed_ips_cache_misses // TODO Replace with a dedicated variable + .inc(CacheOutcome::Hit); + return Ok(allowed_vcp_endpoint_ids); + } + + Metrics::get() + .proxy + .allowed_ips_cache_misses // TODO Replace with a dedicated variable + .inc(CacheOutcome::Miss); + + let auth_info = self.do_get_auth_info(ctx, user_info).await?; + let allowed_ips = Arc::new(auth_info.allowed_ips); + let allowed_vcp_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let user = &user_info.user; + if let Some(project_id) = auth_info.project_id { + let normalized_ep_int = normalized_ep.into(); + self.caches.project_info.insert_role_secret( + project_id, + normalized_ep_int, + user.into(), + auth_info.secret.clone(), + ); + self.caches.project_info.insert_allowed_ips( + project_id, + normalized_ep_int, + allowed_ips.clone(), + ); + self.caches.project_info.insert_allowed_vpc_endpoint_ids( + project_id, + normalized_ep_int, + allowed_vcp_endpoint_ids.clone(), + ); + ctx.set_project_id(project_id); + } + Ok(Cached::new_uncached(allowed_vcp_endpoint_ids)) } #[tracing::instrument(skip_all)] diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 5a1ba972f093..40c60513973d 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -216,20 +216,28 @@ impl super::ControlPlaneApi for MockControlPlane { )) } - async fn get_allowed_ips_and_secret( + async fn get_allowed_ips( &self, _ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), GetAuthInfoError> { - Ok(( + ) -> Result { + Ok( Cached::new_uncached(Arc::new( self.do_get_auth_info(user_info).await?.allowed_ips, - )), + )) + ) + } + + async fn get_allowed_vpc_endpoint_ids( + &self, + _ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + Ok( Cached::new_uncached(Arc::new( self.do_get_auth_info(user_info).await?.allowed_vpc_endpoint_ids, )), - None, - )) + ) } async fn get_endpoint_jwks( diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index aec8f06345fd..47a95678a715 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -55,17 +55,31 @@ impl ControlPlaneApi for ControlPlaneClient { } } - async fn get_allowed_ips_and_secret( + async fn get_allowed_ips( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError> { + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips_and_secret(), + Self::Test(api) => api.get_allowed_ips(), + } + } + + async fn get_allowed_vpc_endpoint_ids( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + match self { + Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, + #[cfg(any(test, feature = "testing"))] + Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, + #[cfg(test)] + Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), } } @@ -102,9 +116,13 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips_and_secret( + fn get_allowed_ips( + &self, + ) -> Result; + + fn get_allowed_vpc_endpoint_ids( &self, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError>; + ) -> Result; fn dyn_clone(&self) -> Box; } diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index 0ae3cd70fca2..a863072d9ab5 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -474,13 +474,13 @@ mod tests { "role_secret": "secret", "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"], }); - serde_json::from_str::(&json.to_string())?; + serde_json::from_str::(&json.to_string())?; let json = json!({ "role_secret": "secret", "allowed_ips": ["8.8.8.8"], "allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"], }); - serde_json::from_str::(&json.to_string())?; + serde_json::from_str::(&json.to_string())?; let json = json!({ "role_secret": "secret", "allowed_ips": ["8.8.8.8"], diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 60db0ac2bd6f..90de586ff53f 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -116,12 +116,17 @@ pub(crate) trait ControlPlaneApi { user_info: &ComputeUserInfo, ) -> Result; - // TODO: Should we rename this one? It does more... - async fn get_allowed_ips_and_secret( + async fn get_allowed_ips( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), errors::GetAuthInfoError>; + ) -> Result; + + async fn get_allowed_vpc_endpoint_ids( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result; async fn get_endpoint_jwks( &self, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 769040cf0e82..8b00fa8586e3 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,7 +26,7 @@ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, + self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; @@ -526,13 +526,20 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips_and_secret( + fn get_allowed_ips( &self, - ) -> Result<(CachedAllowedIps, CachedAllowedVpcEndpointIds, Option), control_plane::errors::GetAuthInfoError> + ) -> Result { unimplemented!("not used in tests") } + fn get_allowed_vpc_endpoint_ids( + &self, + ) -> Result { + unimplemented!("not used in tests") + + } + fn dyn_clone(&self) -> Box { Box::new(self.clone()) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 051bdef4bcea..d3448e5b7e41 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -58,8 +58,7 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let (allowed_ips, allowed_vpce_ids, maybe_secret) = - backend.get_allowed_ips_and_secret(ctx).await?; + let allowed_ips = backend.get_allowed_ips(ctx).await?; let extra = ctx.extra(); let incoming_endpoint_id = match extra { @@ -74,26 +73,25 @@ impl PoolingBackend { Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; - if incoming_endpoint_id != "" && !allowed_vpce_ids.contains(&incoming_endpoint_id) { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); } + if incoming_endpoint_id != "" { + let allowed_vcp_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; + if !allowed_vcp_endpoint_ids.contains(&incoming_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + } + } + if !self .endpoint_rate_limiter .check(user_info.endpoint.clone().into(), 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = match maybe_secret { - Some(secret) => secret, - None => backend.get_role_secret(ctx).await?, - }; - + let cached_secret = backend.get_role_secret(ctx).await?; let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, From 9a3d60af8f491f2687342465bc145f59032aefbb Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Wed, 18 Dec 2024 10:11:15 +0100 Subject: [PATCH 04/16] Added metrics for VPC endpoint IDs --- proxy/src/control_plane/client/cplane_proxy_v1.rs | 9 ++++++--- proxy/src/metrics.rs | 7 +++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index fad2ec21fd76..14e8cfe21b4a 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -151,7 +151,10 @@ impl NeonControlPlaneClient { .allowed_ips_number .observe(allowed_ips.len() as f64); let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default(); - // TODO: Add metrics? + Metrics::get() + .proxy + .allowed_vpc_endpoint_ids + .observe(allowed_vpc_endpoint_ids.len() as f64); Ok(AuthInfo { secret, allowed_ips, @@ -376,14 +379,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { if let Some(allowed_vcp_endpoint_ids) = self.caches.project_info.get_allowed_vpc_endpoint_ids(normalized_ep) { Metrics::get() .proxy - .allowed_ips_cache_misses // TODO Replace with a dedicated variable + .vpc_endpoint_id_cache_stats .inc(CacheOutcome::Hit); return Ok(allowed_vcp_endpoint_ids); } Metrics::get() .proxy - .allowed_ips_cache_misses // TODO Replace with a dedicated variable + .vpc_endpoint_id_cache_stats .inc(CacheOutcome::Miss); let auth_info = self.do_get_auth_info(ctx, user_info).await?; diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index f3d281a26b59..3c523efc53d7 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -96,6 +96,13 @@ pub struct ProxyMetrics { #[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))] pub allowed_ips_number: Histogram<10>, + /// Number of cache hits/misses for VPC endpoint IDs. + pub vpc_endpoint_id_cache_stats: CounterVec>, + + /// Number of allowed VPC endpoints IDs + #[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))] + pub allowed_vpc_endpoint_ids: Histogram<10>, + /// Number of connections (per sni). pub accepted_connections_by_sni: CounterVec>, From d0f93c3ff9065a5cc01072afd600c853b03f6c4e Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Wed, 18 Dec 2024 10:15:20 +0100 Subject: [PATCH 05/16] For now of no VCP endpoints are allowed allow all --- proxy/src/auth/backend/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 7eb7688f8a29..24af6e3b8532 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -310,7 +310,8 @@ async fn auth_quirks( }; if incoming_endpoint_id != "" { let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - if !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !allowed_vpc_endpoint_ids.is_empty() && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } } From cfd6ae1fea35e2e321dc389b498c1abf9cae2dbf Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Wed, 18 Dec 2024 10:41:38 +0100 Subject: [PATCH 06/16] Formatting fixed. --- proxy/src/auth/backend/mod.rs | 20 ++++++------- proxy/src/cache/project_info.rs | 17 ++++++++--- .../control_plane/client/cplane_proxy_v1.rs | 9 ++++-- proxy/src/control_plane/client/mock.rs | 30 +++++++++---------- proxy/src/control_plane/client/mod.rs | 7 ++--- proxy/src/control_plane/mod.rs | 5 ++-- proxy/src/proxy/tests/mod.rs | 12 +++----- 7 files changed, 54 insertions(+), 46 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 24af6e3b8532..f232ae82936b 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -311,7 +311,9 @@ async fn auth_quirks( if incoming_endpoint_id != "" { let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { + if !allowed_vpc_endpoint_ids.is_empty() + && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) + { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } } @@ -466,12 +468,8 @@ impl Backend<'_, ComputeUserInfo> { ctx: &RequestContext, ) -> Result { match self { - Self::ControlPlane(api, user_info) => { - api.get_allowed_ips(ctx, user_info).await - } - Self::Local(_) => Ok( - Cached::new_uncached(Arc::new(vec![])), - ), + Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, + Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), } } @@ -483,9 +481,7 @@ impl Backend<'_, ComputeUserInfo> { Self::ControlPlane(api, user_info) => { api.get_allowed_vpc_endpoint_ids(ctx, user_info).await } - Self::Local(_) => Ok( - Cached::new_uncached(Arc::new(vec![])), - ), + Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), } } } @@ -588,7 +584,9 @@ mod tests { _ctx: &RequestContext, _user_info: &super::ComputeUserInfo, ) -> Result { - Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(self.vpc_endpoint_ids.clone()))) + Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( + self.vpc_endpoint_ids.clone(), + ))) } async fn get_endpoint_jwks( diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 19a7bd0a4abc..37d9d91efbe3 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -102,7 +102,10 @@ impl EndpointInfo { if valid_since < allowed_vpc_endpoint_ids.created_at { return Some(( allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_vpc_endpoint_ids.created_at), + Self::check_ignore_cache( + ignore_cache_since, + allowed_vpc_endpoint_ids.created_at, + ), )); } } @@ -256,7 +259,10 @@ impl ProjectInfoCacheImpl { let (value, ignore_cache) = value?; if !ignore_cache { let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id))), + token: Some(( + self, + CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), + )), value, }; return Some(cached); @@ -304,7 +310,10 @@ impl ProjectInfoCacheImpl { return; } self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); + self.cache + .entry(endpoint_id) + .or_default() + .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) { @@ -395,7 +404,7 @@ impl CachedLookupInfo { enum LookupType { RoleSecret(RoleNameInt), AllowedIps, - AllowedVpcEndpointIds + AllowedVpcEndpointIds, } impl Cache for ProjectInfoCacheImpl { diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 14e8cfe21b4a..e904fd31872e 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -22,7 +22,8 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AuthInfo, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AuthInfo, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, + CachedRoleSecret, NodeInfo, }; use crate::metrics::{CacheOutcome, Metrics}; use crate::rate_limiter::WakeComputeRateLimiter; @@ -376,7 +377,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { user_info: &ComputeUserInfo, ) -> Result { let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vcp_endpoint_ids) = self.caches.project_info.get_allowed_vpc_endpoint_ids(normalized_ep) { + if let Some(allowed_vcp_endpoint_ids) = self + .caches + .project_info + .get_allowed_vpc_endpoint_ids(normalized_ep) + { Metrics::get() .proxy .vpc_endpoint_id_cache_stats diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 40c60513973d..f2c11a379ebc 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -13,7 +13,9 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::IpPattern; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret}; +use crate::control_plane::client::{ + CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, +}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; @@ -221,23 +223,21 @@ impl super::ControlPlaneApi for MockControlPlane { _ctx: &RequestContext, user_info: &ComputeUserInfo, ) -> Result { - Ok( - Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - )) - ) + Ok(Cached::new_uncached(Arc::new( + self.do_get_auth_info(user_info).await?.allowed_ips, + ))) } async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok( - Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_vpc_endpoint_ids, - )), - ) + &self, + _ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + Ok(Cached::new_uncached(Arc::new( + self.do_get_auth_info(user_info) + .await? + .allowed_vpc_endpoint_ids, + ))) } async fn get_endpoint_jwks( diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 47a95678a715..a6700a043a64 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -17,7 +17,8 @@ use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; use crate::control_plane::{ - errors, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache, + errors, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + ControlPlaneApi, NodeInfoCache, }; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; @@ -116,9 +117,7 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips( - &self, - ) -> Result; + fn get_allowed_ips(&self) -> Result; fn get_allowed_vpc_endpoint_ids( &self, diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 90de586ff53f..a7106b8d3848 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -53,7 +53,7 @@ pub(crate) struct AuthInfo { /// List of IP addresses allowed for the autorization. pub(crate) allowed_ips: Vec, /// List of VPC endpoints allowed for the autorization. - pub(crate) allowed_vpc_endpoint_ids: Vec, + pub(crate) allowed_vpc_endpoint_ids: Vec, /// Project ID. This is used for cache invalidation. pub(crate) project_id: Option, } @@ -102,7 +102,8 @@ pub(crate) type NodeInfoCache = pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub(crate) type CachedAllowedVpcEndpointIds = + Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 8b00fa8586e3..5261033a51ee 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -526,18 +526,14 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips( - &self, - ) -> Result - { + fn get_allowed_ips(&self) -> Result { unimplemented!("not used in tests") } fn get_allowed_vpc_endpoint_ids( - &self, - ) -> Result { - unimplemented!("not used in tests") - + &self, + ) -> Result { + unimplemented!("not used in tests") } fn dyn_clone(&self) -> Box { From 0d76a070cdcdbe4dfc32668e55338a3ba7eec949 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Thu, 19 Dec 2024 10:09:00 +0100 Subject: [PATCH 07/16] Wired up invalidation of VPC endpoint IDs and block public / private. --- proxy/src/cache/project_info.rs | 99 ++++++++++++++++++- .../control_plane/client/cplane_proxy_v1.rs | 25 +++-- proxy/src/control_plane/messages.rs | 7 +- proxy/src/control_plane/mod.rs | 7 ++ proxy/src/intern.rs | 22 ++++- proxy/src/metrics.rs | 3 + proxy/src/redis/notifications.rs | 55 ++++++++--- proxy/src/serverless/backend.rs | 6 +- proxy/src/types.rs | 2 + 9 files changed, 198 insertions(+), 28 deletions(-) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 37d9d91efbe3..aa6a9b4772cd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -16,12 +16,15 @@ use super::{Cache, Cached}; use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::AuthSecret; -use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); + fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); + fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); + fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -51,6 +54,7 @@ impl From for Entry { struct EndpointInfo { secret: std::collections::HashMap>>, allowed_ips: Option>>>, + block_public_or_vpc_access: Option>, allowed_vpc_endpoint_ids: Option>>>, } @@ -111,12 +115,34 @@ impl EndpointInfo { } None } + pub(crate) fn get_block_public_or_vpc_access( + &self, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<((bool, bool), bool)> { + if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { + if valid_since < block_public_or_vpc_access.created_at { + return Some(( + block_public_or_vpc_access.value.clone(), + Self::check_ignore_cache( + ignore_cache_since, + block_public_or_vpc_access.created_at, + ), + )); + } + } + None + } + pub(crate) fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { self.allowed_vpc_endpoint_ids = None; } + pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { + self.block_public_or_vpc_access = None; + } pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { self.secret.remove(&role_name); } @@ -133,6 +159,8 @@ pub struct ProjectInfoCacheImpl { cache: ClashMap, project2ep: ClashMap>, + // FIXME(stefan): we need a way to GC the account2ep map. + account2ep: ClashMap>, config: ProjectInfoCacheOptions, start_time: Instant, @@ -142,6 +170,63 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { + info!( + "invalidating allowed vpc endpoint ids for projects `{}`", + project_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(", ") + ); + for project_id in project_ids { + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + } + } + } + } + + fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { + info!( + "invalidating allowed vpc endpoint ids for org `{}`", + account_id + ); + let endpoints = self + .account2ep + .get(&account_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + } + } + } + + fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { + info!( + "invalidating block public or vpc access for project `{}`", + project_id + ); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_block_public_or_vpc_access(); + } + } + } + fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { info!("invalidating allowed ips for project `{}`", project_id); let endpoints = self @@ -200,6 +285,7 @@ impl ProjectInfoCacheImpl { Self { cache: ClashMap::new(), project2ep: ClashMap::new(), + account2ep: ClashMap::new(), config, ttl_disabled_since_us: AtomicU64::new(u64::MAX), start_time: Instant::now(), @@ -301,6 +387,7 @@ impl ProjectInfoCacheImpl { } pub(crate) fn insert_allowed_vpc_endpoint_ids( &self, + account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, allowed_vpc_endpoint_ids: Arc>, @@ -309,6 +396,9 @@ impl ProjectInfoCacheImpl { // If there are too many entries, wait until the next gc cycle. return; } + if let Some(account_id) = account_id { + self.insert_account2endpoint(account_id, endpoint_id); + } self.insert_project2endpoint(project_id, endpoint_id); self.cache .entry(endpoint_id) @@ -323,6 +413,13 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { + if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { + endpoints.insert(endpoint_id); + } else { + self.account2ep.insert(account_id, HashSet::from([endpoint_id])); + } + } fn get_cache_times(&self) -> (Instant, Option) { let mut valid_since = Instant::now() - self.config.ttl; // Only ignore cache if ttl is disabled. diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index e904fd31872e..923cc0260024 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -156,11 +156,16 @@ impl NeonControlPlaneClient { .proxy .allowed_vpc_endpoint_ids .observe(allowed_vpc_endpoint_ids.len() as f64); + let block_public_connections = body.block_public_connections.unwrap_or_default(); + let block_vpc_connections = body.block_vpc_connections.unwrap_or_default(); Ok(AuthInfo { secret, allowed_ips, allowed_vpc_endpoint_ids, project_id: body.project_id, + account_id: body.account_id, + block_public_connections, + block_private_connections: block_vpc_connections, }) } .inspect_err(|e| tracing::debug!(error = ?e)) @@ -303,6 +308,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { return Ok(role_secret); } let auth_info = self.do_get_auth_info(ctx, user_info).await?; + let account_id = auth_info.account_id; if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); self.caches.project_info.insert_role_secret( @@ -317,6 +323,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { Arc::new(auth_info.allowed_ips), ); self.caches.project_info.insert_allowed_vpc_endpoint_ids( + account_id, project_id, normalized_ep_int, Arc::new(auth_info.allowed_vpc_endpoint_ids), @@ -346,8 +353,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { .inc(CacheOutcome::Miss); let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vcp_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); let user = &user_info.user; + let account_id = auth_info.account_id; if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); self.caches.project_info.insert_role_secret( @@ -362,9 +370,10 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips.clone(), ); self.caches.project_info.insert_allowed_vpc_endpoint_ids( + account_id, project_id, normalized_ep_int, - allowed_vcp_endpoint_ids.clone(), + allowed_vpc_endpoint_ids.clone(), ); ctx.set_project_id(project_id); } @@ -377,7 +386,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { user_info: &ComputeUserInfo, ) -> Result { let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vcp_endpoint_ids) = self + if let Some(allowed_vpc_endpoint_ids) = self .caches .project_info .get_allowed_vpc_endpoint_ids(normalized_ep) @@ -386,7 +395,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { .proxy .vpc_endpoint_id_cache_stats .inc(CacheOutcome::Hit); - return Ok(allowed_vcp_endpoint_ids); + return Ok(allowed_vpc_endpoint_ids); } Metrics::get() @@ -396,8 +405,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vcp_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); let user = &user_info.user; + let account_id = auth_info.account_id; if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); self.caches.project_info.insert_role_secret( @@ -412,13 +422,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips.clone(), ); self.caches.project_info.insert_allowed_vpc_endpoint_ids( + account_id, project_id, normalized_ep_int, - allowed_vcp_endpoint_ids.clone(), + allowed_vpc_endpoint_ids.clone(), ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_vcp_endpoint_ids)) + Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) } #[tracing::instrument(skip_all)] diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index a863072d9ab5..d75b51d3078d 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -4,7 +4,7 @@ use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; use crate::auth::IpPattern; -use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::proxy::retry::CouldRetry; /// Generic error response with human-readable description. @@ -229,6 +229,9 @@ pub(crate) struct GetEndpointAccessControl { pub(crate) allowed_ips: Option>, pub(crate) allowed_vpc_endpoint_ids: Option>, pub(crate) project_id: Option, + pub(crate) account_id: Option, + pub(crate) block_public_connections: Option, + pub(crate) block_vpc_connections: Option, } /// Response which holds compute node's `host:port` pair. @@ -460,7 +463,7 @@ mod tests { #[test] fn parse_get_role_secret() -> anyhow::Result<()> { - // Empty `allowed_ips` and `allowed_vcp_endpoint_ids` field. + // Empty `allowed_ips` and `allowed_vpc_endpoint_ids` field. let json = json!({ "role_secret": "secret", }); diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index a7106b8d3848..a3db2b5e7b67 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -19,6 +19,7 @@ use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; +use crate::intern::AccountIdInt; use crate::intern::ProjectIdInt; use crate::types::{EndpointCacheKey, EndpointId}; use crate::{compute, scram}; @@ -56,6 +57,12 @@ pub(crate) struct AuthInfo { pub(crate) allowed_vpc_endpoint_ids: Vec, /// Project ID. This is used for cache invalidation. pub(crate) project_id: Option, + /// Account ID. This is used for cache invalidation. + pub(crate) account_id: Option, + /// Are public connections blocked? + pub(crate) block_public_connections: bool, + /// Are private connections blocked? + pub(crate) block_private_connections: bool, } /// Info for establishing a connection to a compute node. diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index 79c6020302af..0d1382679c81 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -7,7 +7,7 @@ use std::sync::OnceLock; use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; use rustc_hash::FxHasher; -use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; +use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName}; pub trait InternId: Sized + 'static { fn get_interner() -> &'static StringInterner; @@ -206,6 +206,26 @@ impl From for ProjectIdInt { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct AccountIdTag; +impl InternId for AccountIdTag { + fn get_interner() -> &'static StringInterner { + static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type AccountIdInt = InternedString; +impl From<&AccountId> for AccountIdInt { + fn from(value: &AccountId) -> Self { + AccountIdTag::get_interner().get_or_intern(value) + } +} +impl From for AccountIdInt { + fn from(value: AccountId) -> Self { + AccountIdTag::get_interner().get_or_intern(&value) + } +} + #[cfg(test)] #[expect(clippy::unwrap_used)] mod tests { diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 3c523efc53d7..e2fa927a4e5b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -577,6 +577,9 @@ pub enum RedisEventsCount { CancelSession, PasswordUpdate, AllowedIpsUpdate, + AllowedVpcEndpointIdsUpdateForProjects, + AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, + BlockPublicOrVpcAccessUpdate, } pub struct ThreadPoolWorkers(usize); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 19fdd3280dfc..9bd9fecb3b17 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::intern::{ProjectIdInt, RoleNameInt}; +use crate::intern::{ProjectIdInt, RoleNameInt, AccountIdInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -86,9 +86,7 @@ pub(crate) struct BlockPublicOrVpcAccessUpdated { #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub(crate) struct AllowedVpcEndpointsUpdatedForOrg { - // TODO: change type once the implementation is more fully fledged. - // See e.g. https://github.com/neondatabase/neon/pull/10073. - account_id: ProjectIdInt, + account_id: AccountIdInt, } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -205,6 +203,27 @@ impl MessageHandler { .proxy .redis_events_count .inc(RedisEventsCount::PasswordUpdate); + } else if matches!( + msg, + Notification::AllowedVpcEndpointsUpdatedForProjects { .. } + ) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); + } else if matches!( + msg, + Notification::AllowedVpcEndpointsUpdatedForOrg { .. } + ) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg); + } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate); } // TODO: add additional metrics for the other event types. @@ -229,21 +248,27 @@ fn invalidate_cache(cache: Arc, msg: Notification) { match msg { Notification::AllowedIpsUpdate { allowed_ips_update } => { cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); - } + }, + Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated, + } => cache.invalidate_block_public_or_vpc_access_for_project( + block_public_or_vpc_access_updated.project_id, + ), + Notification::AllowedVpcEndpointsUpdatedForOrg { + allowed_vpc_endpoints_updated_for_org, + } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( + allowed_vpc_endpoints_updated_for_org.account_id, + ), + Notification::AllowedVpcEndpointsUpdatedForProjects { + allowed_vpc_endpoints_updated_for_projects, + } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( + allowed_vpc_endpoints_updated_for_projects.project_ids, + ), Notification::PasswordUpdate { password_update } => cache .invalidate_role_secret_for_project( password_update.project_id, password_update.role_name, - ), - Notification::BlockPublicOrVpcAccessUpdated { .. } => { - // https://github.com/neondatabase/neon/pull/10073 - } - Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => { - // https://github.com/neondatabase/neon/pull/10073 - } - Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => { - // https://github.com/neondatabase/neon/pull/10073 - } + ), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index d3448e5b7e41..d427d926d65c 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -79,8 +79,10 @@ impl PoolingBackend { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); } if incoming_endpoint_id != "" { - let allowed_vcp_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - if !allowed_vcp_endpoint_ids.contains(&incoming_endpoint_id) { + let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !allowed_vpc_endpoint_ids.is_empty() && + !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } } diff --git a/proxy/src/types.rs b/proxy/src/types.rs index 6e0bd61c9442..d5952d1d8b0a 100644 --- a/proxy/src/types.rs +++ b/proxy/src/types.rs @@ -97,6 +97,8 @@ smol_str_wrapper!(EndpointId); smol_str_wrapper!(BranchId); // 90% of project strings are 23 characters or less. smol_str_wrapper!(ProjectId); +// 90% of account strings are 23 characters or less. +smol_str_wrapper!(AccountId); // will usually equal endpoint ID smol_str_wrapper!(EndpointCacheKey); From 951cc4849bd83e9ebfa183ca66c0dcd482f0c7f7 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Thu, 19 Dec 2024 10:09:52 +0100 Subject: [PATCH 08/16] Formatting fixes. --- proxy/src/cache/project_info.rs | 3 ++- proxy/src/serverless/backend.rs | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index aa6a9b4772cd..05d5ab51be9f 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -417,7 +417,8 @@ impl ProjectInfoCacheImpl { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); } else { - self.account2ep.insert(account_id, HashSet::from([endpoint_id])); + self.account2ep + .insert(account_id, HashSet::from([endpoint_id])); } } fn get_cache_times(&self) -> (Instant, Option) { diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index d427d926d65c..ffcafeb246b5 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -80,9 +80,10 @@ impl PoolingBackend { } if incoming_endpoint_id != "" { let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() && - !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) { + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !allowed_vpc_endpoint_ids.is_empty() + && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) + { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } } From 8e51a4db573765dc0ef062a6ebe6754cce8a59f6 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Thu, 19 Dec 2024 13:43:50 +0100 Subject: [PATCH 09/16] Wired up flags to block access over public internet or VPC --- proxy/src/auth/backend/mod.rs | 68 +++++++++++---- proxy/src/auth/mod.rs | 15 ++++ proxy/src/bin/local_proxy.rs | 1 + proxy/src/bin/proxy.rs | 1 + proxy/src/cache/project_info.rs | 57 ++++++++++++- proxy/src/config.rs | 1 + .../control_plane/client/cplane_proxy_v1.rs | 83 ++++++++++++++++++- proxy/src/control_plane/client/mock.rs | 10 +++ proxy/src/control_plane/client/mod.rs | 20 ++++- proxy/src/control_plane/mod.rs | 20 ++++- proxy/src/metrics.rs | 3 + proxy/src/proxy/tests/mod.rs | 6 ++ proxy/src/serverless/backend.rs | 42 ++++++---- 13 files changed, 286 insertions(+), 41 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index f232ae82936b..ad98d66c93fd 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -26,7 +26,7 @@ use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, + self, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, }; use crate::intern::EndpointIdInt; @@ -295,26 +295,37 @@ async fn auth_quirks( } // check if a VPC endpoint ID is coming in and if yes, if it's allowed - // TODO: Add flag to enable/disable VPC endpoint ID check - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => "".to_string(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => { - // Convert the vcpe_id to a string - match String::from_utf8(vpce_id.to_vec()) { - Ok(s) => s, - Err(_e) => "".to_string(), + let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; + if config.is_vpc_acccess_proxy { + if access_blocks.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + let extra = ctx.extra(); + let incoming_vpc_endpoint_id = match extra { + None => "".to_string(), + Some(ConnectionInfoExtra::Aws { vpce_id }) => { + // Convert the vcpe_id to a string + match String::from_utf8(vpce_id.to_vec()) { + Ok(s) => s, + Err(_e) => "".to_string(), + } } + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + if incoming_vpc_endpoint_id == "" { + // This should never happen, would be a setup error on our side. + return Err(AuthError::MissingEndpointName); } - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - if incoming_endpoint_id != "" { let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) + && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); + return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_vpc_endpoint_id)); + } + } else { + if access_blocks.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); } } @@ -484,6 +495,16 @@ impl Backend<'_, ComputeUserInfo> { Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), } } + + pub(crate) async fn get_block_public_or_vpc_access( + &self, + ctx: &RequestContext, + ) -> Result { + match self { + Self::ControlPlane(api, user_info) => api.get_block_public_or_vpc_access(ctx, user_info).await, + Self::Local(_) => Ok(Cached::new_uncached(Default::default())), + } + } } #[async_trait::async_trait] @@ -548,7 +569,7 @@ mod tests { use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, }; use crate::proxy::NeonOptions; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; @@ -559,6 +580,7 @@ mod tests { struct Auth { ips: Vec, vpc_endpoint_ids: Vec, + access_blocker_flags: AccessBlockerFlags, secret: AuthSecret, } @@ -589,6 +611,16 @@ mod tests { ))) } + async fn get_block_public_or_vpc_access( + &self, + _ctx: &RequestContext, + _user_info: &super::ComputeUserInfo, + ) -> Result { + Ok(CachedAccessBlockerFlags::new_uncached( + self.access_blocker_flags.clone() + )) + } + async fn get_endpoint_jwks( &self, _ctx: &RequestContext, @@ -615,6 +647,7 @@ mod tests { rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, + is_vpc_acccess_proxy: false, is_auth_broker: false, accept_jwts: false, console_redirect_confirmation_timeout: std::time::Duration::from_secs(5), @@ -683,6 +716,7 @@ mod tests { let api = Auth { ips: vec![], vpc_endpoint_ids: vec![], + access_blocker_flags: AccessBlockerFlags::default(), secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; @@ -764,6 +798,7 @@ mod tests { let api = Auth { ips: vec![], vpc_endpoint_ids: vec![], + access_blocker_flags: AccessBlockerFlags::default(), secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; @@ -817,6 +852,7 @@ mod tests { let api = Auth { ips: vec![], vpc_endpoint_ids: vec![], + access_blocker_flags: AccessBlockerFlags::default(), secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), }; diff --git a/proxy/src/auth/mod.rs b/proxy/src/auth/mod.rs index 987e5de25d56..6ef40b3cc603 100644 --- a/proxy/src/auth/mod.rs +++ b/proxy/src/auth/mod.rs @@ -55,6 +55,12 @@ pub(crate) enum AuthError { )] MissingEndpointName, + #[error( + "VPC endpoint ID is not specified. \ + This endpoint requires a VPC endpoint ID to connect." + )] + MissingVPCEndpointId, + #[error("password authentication failed for user '{0}'")] PasswordFailed(Box), @@ -69,6 +75,11 @@ pub(crate) enum AuthError { )] IpAddressNotAllowed(IpAddr), + #[error( + "This connection is trying to access this endpoint from a blocked network." + )] + NetworkNotAllowed, + #[error( "This VPC endpoint id {0} is not allowed to connect to this endpoint. \ Please add it to the allowed list in the Neon console." @@ -132,8 +143,10 @@ impl UserFacingError for AuthError { Self::BadAuthMethod(_) => self.to_string(), Self::MalformedPassword(_) => self.to_string(), Self::MissingEndpointName => self.to_string(), + Self::MissingVPCEndpointId => self.to_string(), Self::Io(_) => "Internal error".to_string(), Self::IpAddressNotAllowed(_) => self.to_string(), + Self::NetworkNotAllowed => self.to_string(), Self::VpcEndpointIdNotAllowed(_) => self.to_string(), Self::TooManyConnections => self.to_string(), Self::UserTimeout(_) => self.to_string(), @@ -153,8 +166,10 @@ impl ReportableError for AuthError { Self::BadAuthMethod(_) => crate::error::ErrorKind::User, Self::MalformedPassword(_) => crate::error::ErrorKind::User, Self::MissingEndpointName => crate::error::ErrorKind::User, + Self::MissingVPCEndpointId => crate::error::ErrorKind::User, Self::Io(_) => crate::error::ErrorKind::ClientDisconnect, Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User, + Self::NetworkNotAllowed => crate::error::ErrorKind::User, Self::VpcEndpointIdNotAllowed(_) => crate::error::ErrorKind::User, Self::TooManyConnections => crate::error::ErrorKind::RateLimit, Self::UserTimeout(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index ee8b3d4ef579..7a855bf54b41 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -284,6 +284,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig rate_limiter: BucketRateLimiter::new(vec![]), rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, + is_vpc_acccess_proxy: false, is_auth_broker: false, accept_jwts: true, console_redirect_confirmation_timeout: Duration::ZERO, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index e1affe8391a6..de685a82c627 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -630,6 +630,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, + is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, accept_jwts: args.is_auth_broker, console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 05d5ab51be9f..7651eb71a2e0 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -15,7 +15,7 @@ use tracing::{debug, info}; use super::{Cache, Cached}; use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::AuthSecret; +use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; @@ -54,7 +54,7 @@ impl From for Entry { struct EndpointInfo { secret: std::collections::HashMap>>, allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, + block_public_or_vpc_access: Option>, allowed_vpc_endpoint_ids: Option>>>, } @@ -119,7 +119,7 @@ impl EndpointInfo { &self, valid_since: Instant, ignore_cache_since: Option, - ) -> Option<((bool, bool), bool)> { + ) -> Option<(AccessBlockerFlags, bool)> { if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { if valid_since < block_public_or_vpc_access.created_at { return Some(( @@ -355,6 +355,28 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } + pub(crate) fn get_block_public_or_vpc_access( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + let (valid_since, ignore_cache_since) = self.get_cache_times(); + let endpoint_info = self.cache.get(&endpoint_id)?; + let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); + let (value, ignore_cache) = value?; + if !ignore_cache { + let cached = Cached { + token: Some(( + self, + CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), + )), + value, + }; + return Some(cached); + } + Some(Cached::new_uncached(value)) + } + pub(crate) fn insert_role_secret( &self, project_id: ProjectIdInt, @@ -405,6 +427,23 @@ impl ProjectInfoCacheImpl { .or_default() .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); } + pub(crate) fn insert_block_public_or_vpc_access( + &self, + project_id: ProjectIdInt, + endpoint_id: EndpointIdInt, + access_blockers: AccessBlockerFlags, + ) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. + return; + } + self.insert_project2endpoint(project_id, endpoint_id); + self.cache + .entry(endpoint_id) + .or_default() + .block_public_or_vpc_access = Some(access_blockers.into()); + } + fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) { endpoints.insert(endpoint_id); @@ -497,12 +536,19 @@ impl CachedLookupInfo { lookup_type: LookupType::AllowedVpcEndpointIds, } } + pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self { + Self { + endpoint_id, + lookup_type: LookupType::BlockPublicOrVpcAccess, + } + } } enum LookupType { RoleSecret(RoleNameInt), AllowedIps, AllowedVpcEndpointIds, + BlockPublicOrVpcAccess, } impl Cache for ProjectInfoCacheImpl { @@ -529,6 +575,11 @@ impl Cache for ProjectInfoCacheImpl { endpoint_info.invalidate_allowed_vpc_endpoint_ids(); } } + LookupType::BlockPublicOrVpcAccess => { + if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { + endpoint_info.invalidate_block_public_or_vpc_access(); + } + } } } } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 8502edcfab09..1dcd37712ea2 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -68,6 +68,7 @@ pub struct AuthenticationConfig { pub rate_limiter: AuthRateLimiter, pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, + pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, pub is_auth_broker: bool, pub accept_jwts: bool, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 923cc0260024..1d170f5ef468 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -22,7 +22,7 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AuthInfo, AuthSecret, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, + AuthInfo, AuthSecret, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, }; use crate::metrics::{CacheOutcome, Metrics}; @@ -164,8 +164,10 @@ impl NeonControlPlaneClient { allowed_vpc_endpoint_ids, project_id: body.project_id, account_id: body.account_id, - block_public_connections, - block_private_connections: block_vpc_connections, + access_blocker_flags: AccessBlockerFlags { + public_access_blocked: block_public_connections, + vpc_access_blocked: block_vpc_connections, + }, }) } .inspect_err(|e| tracing::debug!(error = ?e)) @@ -328,6 +330,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { normalized_ep_int, Arc::new(auth_info.allowed_vpc_endpoint_ids), ); + self.caches.project_info.insert_block_public_or_vpc_access( + project_id, + normalized_ep_int, + auth_info.access_blocker_flags, + ); ctx.set_project_id(project_id); } // When we just got a secret, we don't need to invalidate it. @@ -354,6 +361,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let access_blocker_flags = auth_info.access_blocker_flags; let user = &user_info.user; let account_id = auth_info.account_id; if let Some(project_id) = auth_info.project_id { @@ -375,6 +383,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { normalized_ep_int, allowed_vpc_endpoint_ids.clone(), ); + self.caches.project_info.insert_block_public_or_vpc_access( + project_id, + normalized_ep_int, + access_blocker_flags, + ); ctx.set_project_id(project_id); } Ok(Cached::new_uncached(allowed_ips)) @@ -406,6 +419,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let access_blocker_flags = auth_info.access_blocker_flags; let user = &user_info.user; let account_id = auth_info.account_id; if let Some(project_id) = auth_info.project_id { @@ -427,11 +441,74 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { normalized_ep_int, allowed_vpc_endpoint_ids.clone(), ); + self.caches.project_info.insert_block_public_or_vpc_access( + project_id, + normalized_ep_int, + access_blocker_flags, + ); ctx.set_project_id(project_id); } Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) } + async fn get_block_public_or_vpc_access( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + let normalized_ep = &user_info.endpoint.normalize(); + if let Some(access_blocker_flags) = self + .caches + .project_info + .get_block_public_or_vpc_access(normalized_ep) + { + Metrics::get() + .proxy + .access_blocker_flags_cache_stats + .inc(CacheOutcome::Hit); + return Ok(access_blocker_flags); + } + + Metrics::get() + .proxy + .access_blocker_flags_cache_stats + .inc(CacheOutcome::Miss); + + let auth_info = self.do_get_auth_info(ctx, user_info).await?; + let allowed_ips = Arc::new(auth_info.allowed_ips); + let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); + let access_blocker_flags = auth_info.access_blocker_flags; + let user = &user_info.user; + let account_id = auth_info.account_id; + if let Some(project_id) = auth_info.project_id { + let normalized_ep_int = normalized_ep.into(); + self.caches.project_info.insert_role_secret( + project_id, + normalized_ep_int, + user.into(), + auth_info.secret.clone(), + ); + self.caches.project_info.insert_allowed_ips( + project_id, + normalized_ep_int, + allowed_ips.clone(), + ); + self.caches.project_info.insert_allowed_vpc_endpoint_ids( + account_id, + project_id, + normalized_ep_int, + allowed_vpc_endpoint_ids.clone(), + ); + self.caches.project_info.insert_block_public_or_vpc_access( + project_id, + normalized_ep_int, + access_blocker_flags.clone(), + ); + ctx.set_project_id(project_id); + } + Ok(Cached::new_uncached(access_blocker_flags)) + } + #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index f2c11a379ebc..6dec5719a6cf 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -126,6 +126,8 @@ impl MockControlPlane { // TODO allowed_vpc_endpoint_ids: vec![], project_id: None, + account_id: None, + access_blocker_flags: Default::default(), }) } @@ -240,6 +242,14 @@ impl super::ControlPlaneApi for MockControlPlane { ))) } + async fn get_block_public_or_vpc_access( + &self, + _ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + Ok(Cached::new_uncached(self.do_get_auth_info(user_info).await?.access_blocker_flags)) + } + async fn get_endpoint_jwks( &self, _ctx: &RequestContext, diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index a6700a043a64..c577931a229f 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -17,7 +17,7 @@ use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; use crate::control_plane::{ - errors, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + errors, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache, }; use crate::error::ReportableError; @@ -84,6 +84,20 @@ impl ControlPlaneApi for ControlPlaneClient { } } + async fn get_block_public_or_vpc_access( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + match self { + Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, + #[cfg(any(test, feature = "testing"))] + Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, + #[cfg(test)] + Self::Test(api) => api.get_block_public_or_vpc_access(), + } + } + async fn get_endpoint_jwks( &self, ctx: &RequestContext, @@ -123,6 +137,10 @@ pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { &self, ) -> Result; + fn get_block_public_or_vpc_access( + &self, + ) -> Result; + fn dyn_clone(&self) -> Box; } diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index a3db2b5e7b67..f92e4f3f6055 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -59,10 +59,8 @@ pub(crate) struct AuthInfo { pub(crate) project_id: Option, /// Account ID. This is used for cache invalidation. pub(crate) account_id: Option, - /// Are public connections blocked? - pub(crate) block_public_connections: bool, - /// Are private connections blocked? - pub(crate) block_private_connections: bool, + /// Are public connections or VPC connections blocked? + pub(crate) access_blocker_flags: AccessBlockerFlags, } /// Info for establishing a connection to a compute node. @@ -104,6 +102,12 @@ impl NodeInfo { } } +#[derive(Clone, Default, Eq, PartialEq, Debug)] +pub(crate) struct AccessBlockerFlags { + pub public_access_blocked: bool, + pub vpc_access_blocked: bool, +} + pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; @@ -111,6 +115,8 @@ pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option< pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; pub(crate) type CachedAllowedVpcEndpointIds = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub(crate) type CachedAccessBlockerFlags = + Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -136,6 +142,12 @@ pub(crate) trait ControlPlaneApi { user_info: &ComputeUserInfo, ) -> Result; + async fn get_block_public_or_vpc_access( + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result; + async fn get_endpoint_jwks( &self, ctx: &RequestContext, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index e2fa927a4e5b..d5be93a747eb 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -99,6 +99,9 @@ pub struct ProxyMetrics { /// Number of cache hits/misses for VPC endpoint IDs. pub vpc_endpoint_id_cache_stats: CounterVec>, + /// Number of cache hits/misses for access blocker flags. + pub access_blocker_flags_cache_stats: CounterVec>, + /// Number of allowed VPC endpoints IDs #[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))] pub allowed_vpc_endpoint_ids: Histogram<10>, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 5261033a51ee..e7103ba6eacb 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -536,6 +536,12 @@ impl TestControlPlaneClient for TestConnectMechanism { unimplemented!("not used in tests") } + fn get_block_public_or_vpc_access( + &self, + ) -> Result { + unimplemented!("not used in tests") + } + fn dyn_clone(&self) -> Box { Box::new(self.clone()) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index ffcafeb246b5..dc757a6d5b1d 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -60,25 +60,35 @@ impl PoolingBackend { let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); let allowed_ips = backend.get_allowed_ips(ctx).await?; - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => "".to_string(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => { - // Convert the vcpe_id to a string - match String::from_utf8(vpce_id.to_vec()) { - Ok(s) => s, - Err(_e) => "".to_string(), - } - } - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); } - if incoming_endpoint_id != "" { + + let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; + if self.config.authentication_config.is_vpc_acccess_proxy { + if access_blocker_flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let extra = ctx.extra(); + let incoming_endpoint_id = match extra { + None => "".to_string(), + Some(ConnectionInfoExtra::Aws { vpce_id }) => { + // Convert the vcpe_id to a string + match String::from_utf8(vpce_id.to_vec()) { + Ok(s) => s, + Err(_e) => "".to_string(), + } + } + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + if incoming_endpoint_id == "" { + return Err(AuthError::MissingVPCEndpointId); + } + let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. if !allowed_vpc_endpoint_ids.is_empty() @@ -86,6 +96,10 @@ impl PoolingBackend { { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } + } else { + if access_blocker_flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } } if !self From 6edd61feb0d531a1c007be07e01d7945308c39fd Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Thu, 19 Dec 2024 15:11:52 +0100 Subject: [PATCH 10/16] Add check to console redirect proxy to block public access if configured. --- proxy/src/auth/backend/console_redirect.rs | 9 +++++++++ proxy/src/control_plane/messages.rs | 2 ++ 2 files changed, 11 insertions(+) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 1cbf91d3ae73..b1b09f0260b3 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -191,6 +191,15 @@ async fn authenticate( } } + // Check if the access over the public internet is allowed, otherwise block. Note that + // the console redirect is not behind the VPC service endpoint, so we don't need to check + // the VPC endpoint ID. + if let Some(public_access_allowed) = db_info.public_access_allowed { + if !public_access_allowed { + return Err(auth::AuthError::NetworkNotAllowed); + } + } + client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; // This config should be self-contained, because we won't diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index d75b51d3078d..5883d02b92c7 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -287,6 +287,8 @@ pub(crate) struct DatabaseInfo { pub(crate) allowed_ips: Option>, #[serde(default)] pub(crate) allowed_vpc_endpoint_ids: Option>, + #[serde(default)] + pub(crate) public_access_allowed: Option, } // Manually implement debug to omit sensitive info. From 1a8b251773fdee071deadd153edf197b50708479 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Thu, 19 Dec 2024 16:05:00 +0100 Subject: [PATCH 11/16] Formatting fixes --- proxy/src/auth/backend/mod.rs | 23 +++++++++++-------- proxy/src/auth/mod.rs | 4 +--- .../control_plane/client/cplane_proxy_v1.rs | 4 ++-- proxy/src/control_plane/client/mock.rs | 14 ++++++----- proxy/src/control_plane/client/mod.rs | 12 +++++----- proxy/src/metrics.rs | 2 +- proxy/src/proxy/tests/mod.rs | 5 ++-- proxy/src/serverless/backend.rs | 2 +- 8 files changed, 36 insertions(+), 30 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index ad98d66c93fd..c61544dda72c 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -26,8 +26,8 @@ use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, + self, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, + CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; @@ -321,7 +321,9 @@ async fn auth_quirks( if !allowed_vpc_endpoint_ids.is_empty() && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_vpc_endpoint_id)); + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); } } else { if access_blocks.public_access_blocked { @@ -501,7 +503,9 @@ impl Backend<'_, ComputeUserInfo> { ctx: &RequestContext, ) -> Result { match self { - Self::ControlPlane(api, user_info) => api.get_block_public_or_vpc_access(ctx, user_info).await, + Self::ControlPlane(api, user_info) => { + api.get_block_public_or_vpc_access(ctx, user_info).await + } Self::Local(_) => Ok(Cached::new_uncached(Default::default())), } } @@ -569,7 +573,8 @@ mod tests { use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, + CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, }; use crate::proxy::NeonOptions; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; @@ -612,12 +617,12 @@ mod tests { } async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, + &self, + _ctx: &RequestContext, + _user_info: &super::ComputeUserInfo, ) -> Result { Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone() + self.access_blocker_flags.clone(), )) } diff --git a/proxy/src/auth/mod.rs b/proxy/src/auth/mod.rs index 6ef40b3cc603..6082695a6b1b 100644 --- a/proxy/src/auth/mod.rs +++ b/proxy/src/auth/mod.rs @@ -75,9 +75,7 @@ pub(crate) enum AuthError { )] IpAddressNotAllowed(IpAddr), - #[error( - "This connection is trying to access this endpoint from a blocked network." - )] + #[error("This connection is trying to access this endpoint from a blocked network.")] NetworkNotAllowed, #[error( diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 1d170f5ef468..ef6621fc598a 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -22,8 +22,8 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AuthInfo, AuthSecret, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, + CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, }; use crate::metrics::{CacheOutcome, Metrics}; use crate::rate_limiter::WakeComputeRateLimiter; diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 6dec5719a6cf..b97dcd8e8757 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -243,12 +243,14 @@ impl super::ControlPlaneApi for MockControlPlane { } async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(self.do_get_auth_info(user_info).await?.access_blocker_flags)) - } + &self, + _ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + Ok(Cached::new_uncached( + self.do_get_auth_info(user_info).await?.access_blocker_flags, + )) + } async fn get_endpoint_jwks( &self, diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index c577931a229f..a06943726e50 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -17,8 +17,8 @@ use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; use crate::control_plane::{ - errors, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, - ControlPlaneApi, NodeInfoCache, + errors, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, + CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache, }; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; @@ -85,10 +85,10 @@ impl ControlPlaneApi for ControlPlaneClient { } async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + &self, + ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { match self { Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index d5be93a747eb..25bcc81108b6 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -101,7 +101,7 @@ pub struct ProxyMetrics { /// Number of cache hits/misses for access blocker flags. pub access_blocker_flags_cache_stats: CounterVec>, - + /// Number of allowed VPC endpoints IDs #[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))] pub allowed_vpc_endpoint_ids: Histogram<10>, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index e7103ba6eacb..d8c00a9b4177 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -538,10 +538,11 @@ impl TestControlPlaneClient for TestConnectMechanism { fn get_block_public_or_vpc_access( &self, - ) -> Result { + ) -> Result + { unimplemented!("not used in tests") } - + fn dyn_clone(&self) -> Box { Box::new(self.clone()) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index dc757a6d5b1d..cac10644ea76 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -84,7 +84,7 @@ impl PoolingBackend { } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; - + if incoming_endpoint_id == "" { return Err(AuthError::MissingVPCEndpointId); } From 2704293bb407431ca022f1e38e44771e853d9d03 Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Mon, 23 Dec 2024 15:56:39 +0100 Subject: [PATCH 12/16] Addresses review comments from Ivan --- proxy/src/auth/backend/mod.rs | 6 +----- proxy/src/control_plane/client/mock.rs | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index c61544dda72c..5cd17b276e85 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -302,7 +302,7 @@ async fn auth_quirks( } let extra = ctx.extra(); let incoming_vpc_endpoint_id = match extra { - None => "".to_string(), + None => return Err(AuthError::MissingEndpointName), Some(ConnectionInfoExtra::Aws { vpce_id }) => { // Convert the vcpe_id to a string match String::from_utf8(vpce_id.to_vec()) { @@ -312,10 +312,6 @@ async fn auth_quirks( } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; - if incoming_vpc_endpoint_id == "" { - // This should never happen, would be a setup error on our side. - return Err(AuthError::MissingEndpointName); - } let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. if !allowed_vpc_endpoint_ids.is_empty() diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index b97dcd8e8757..fbf4de794896 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -123,7 +123,6 @@ impl MockControlPlane { Ok(AuthInfo { secret, allowed_ips, - // TODO allowed_vpc_endpoint_ids: vec![], project_id: None, account_id: None, From d45319ac0825596246975b711ce5462d706ff5ba Mon Sep 17 00:00:00 2001 From: Stefan Radig Date: Fri, 3 Jan 2025 14:26:19 +0100 Subject: [PATCH 13/16] Address clippy comments. --- proxy/src/auth/backend/mod.rs | 14 ++++++-------- proxy/src/control_plane/client/mock.rs | 4 ++-- proxy/src/serverless/backend.rs | 12 +++++------- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 5cd17b276e85..b24266d3b737 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -26,8 +26,8 @@ use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, - CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, + CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; @@ -307,7 +307,7 @@ async fn auth_quirks( // Convert the vcpe_id to a string match String::from_utf8(vpce_id.to_vec()) { Ok(s) => s, - Err(_e) => "".to_string(), + Err(_e) => String::new(), } } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), @@ -321,10 +321,8 @@ async fn auth_quirks( incoming_vpc_endpoint_id, )); } - } else { - if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + } else if access_blocks.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); } if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { @@ -502,7 +500,7 @@ impl Backend<'_, ComputeUserInfo> { Self::ControlPlane(api, user_info) => { api.get_block_public_or_vpc_access(ctx, user_info).await } - Self::Local(_) => Ok(Cached::new_uncached(Default::default())), + Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())), } } } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index fbf4de794896..1e6cde8fb080 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -20,7 +20,7 @@ use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; use crate::error::io_error; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; @@ -126,7 +126,7 @@ impl MockControlPlane { allowed_vpc_endpoint_ids: vec![], project_id: None, account_id: None, - access_blocker_flags: Default::default(), + access_blocker_flags: AccessBlockerFlags::default(), }) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index cac10644ea76..ba036364102b 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -74,18 +74,18 @@ impl PoolingBackend { let extra = ctx.extra(); let incoming_endpoint_id = match extra { - None => "".to_string(), + None => String::new(), Some(ConnectionInfoExtra::Aws { vpce_id }) => { // Convert the vcpe_id to a string match String::from_utf8(vpce_id.to_vec()) { Ok(s) => s, - Err(_e) => "".to_string(), + Err(_e) => String::new(), } } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; - if incoming_endpoint_id == "" { + if incoming_endpoint_id.is_empty() { return Err(AuthError::MissingVPCEndpointId); } @@ -96,10 +96,8 @@ impl PoolingBackend { { return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); } - } else { - if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + } else if access_blocker_flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); } if !self From 1dfc754248d246908cf525b76567b395d623dd0d Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Tue, 21 Jan 2025 11:28:31 +0200 Subject: [PATCH 14/16] Fix merge conflict and clippy --- proxy/src/auth/backend/console_redirect.rs | 4 ++-- proxy/src/auth/backend/mod.rs | 11 ++++------- proxy/src/redis/notifications.rs | 7 ++----- proxy/src/serverless/backend.rs | 5 +---- 4 files changed, 9 insertions(+), 18 deletions(-) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index b1b09f0260b3..ebd6f0e60dd4 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -92,9 +92,9 @@ impl BackendIpAllowlist for ConsoleRedirectBackend { user_info: &ComputeUserInfo, ) -> auth::Result> { self.api - .get_allowed_ips_and_secret(ctx, user_info) + .get_allowed_ips(ctx, user_info) .await - .map(|(ips, _)| ips.as_ref().clone()) + .map(|ips| ips.as_ref().clone()) .map_err(|e| e.into()) } } diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index b24266d3b737..0f4888527231 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -305,10 +305,7 @@ async fn auth_quirks( None => return Err(AuthError::MissingEndpointName), Some(ConnectionInfoExtra::Aws { vpce_id }) => { // Convert the vcpe_id to a string - match String::from_utf8(vpce_id.to_vec()) { - Ok(s) => s, - Err(_e) => String::new(), - } + String::from_utf8(vpce_id.to_vec()).unwrap_or_default() } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; @@ -513,12 +510,12 @@ impl BackendIpAllowlist for Backend<'_, ()> { user_info: &ComputeUserInfo, ) -> auth::Result> { let auth_data = match self { - Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::ControlPlane(api, ()) => api.get_allowed_ips(ctx, user_info).await, + Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), }; auth_data - .map(|(ips, _)| ips.as_ref().clone()) + .map(|ips| ips.as_ref().clone()) .map_err(|e| e.into()) } } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 9bd9fecb3b17..84e3c580a9f4 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -211,10 +211,7 @@ impl MessageHandler { .proxy .redis_events_count .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); - } else if matches!( - msg, - Notification::AllowedVpcEndpointsUpdatedForOrg { .. } - ) { + } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) { Metrics::get() .proxy .redis_events_count @@ -248,7 +245,7 @@ fn invalidate_cache(cache: Arc, msg: Notification) { match msg { Notification::AllowedIpsUpdate { allowed_ips_update } => { cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); - }, + } Notification::BlockPublicOrVpcAccessUpdated { block_public_or_vpc_access_updated, } => cache.invalidate_block_public_or_vpc_access_for_project( diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index ba036364102b..0fb4a8a6cc70 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -77,10 +77,7 @@ impl PoolingBackend { None => String::new(), Some(ConnectionInfoExtra::Aws { vpce_id }) => { // Convert the vcpe_id to a string - match String::from_utf8(vpce_id.to_vec()) { - Ok(s) => s, - Err(_e) => String::new(), - } + String::from_utf8(vpce_id.to_vec()).unwrap_or_default() } Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; From fdc0c78c4f51b286c5ded1058a8f57e9baf1123d Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Mon, 27 Jan 2025 10:40:10 +0200 Subject: [PATCH 15/16] Add to the cancellation VPC logic --- proxy/src/auth/backend/console_redirect.rs | 23 +++------ proxy/src/auth/backend/mod.rs | 38 ++++----------- proxy/src/cancellation.rs | 55 +++++++++++++++++++--- proxy/src/console_redirect_proxy.rs | 3 +- proxy/src/proxy/mod.rs | 3 +- proxy/src/redis/notifications.rs | 4 +- 6 files changed, 69 insertions(+), 57 deletions(-) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index ebd6f0e60dd4..9be29c38c938 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -7,8 +7,8 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; -use super::{ComputeCredentialKeys, ControlPlaneApi}; -use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo}; +use super::ComputeCredentialKeys; +use crate::auth::backend::ComputeUserInfo; use crate::auth::IpPattern; use crate::cache::Cached; use crate::config::AuthenticationConfig; @@ -84,26 +84,15 @@ pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -#[async_trait] -impl BackendIpAllowlist for ConsoleRedirectBackend { - async fn get_allowed_ips( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> auth::Result> { - self.api - .get_allowed_ips(ctx, user_info) - .await - .map(|ips| ips.as_ref().clone()) - .map_err(|e| e.into()) - } -} - impl ConsoleRedirectBackend { pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self { Self { console_uri, api } } + pub(crate) fn get_api(&self) -> &cplane_proxy_v1::NeonControlPlaneClient { + &self.api + } + pub(crate) async fn authenticate( &self, ctx: &RequestContext, diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 0f4888527231..892bc74c2134 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -101,6 +101,13 @@ impl Backend<'_, T> { Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)), } } + + pub(crate) fn get_api(&self) -> &ControlPlaneClient { + match self { + Self::ControlPlane(api, _) => api, + Self::Local(_) => panic!("Local backend has no API"), + } + } } impl<'a, T> Backend<'a, T> { @@ -249,15 +256,6 @@ impl AuthenticationConfig { } } -#[async_trait::async_trait] -pub(crate) trait BackendIpAllowlist { - async fn get_allowed_ips( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> auth::Result>; -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -300,8 +298,8 @@ async fn auth_quirks( if access_blocks.vpc_access_blocked { return Err(AuthError::NetworkNotAllowed); } - let extra = ctx.extra(); - let incoming_vpc_endpoint_id = match extra { + + let incoming_vpc_endpoint_id = match ctx.extra() { None => return Err(AuthError::MissingEndpointName), Some(ConnectionInfoExtra::Aws { vpce_id }) => { // Convert the vcpe_id to a string @@ -502,24 +500,6 @@ impl Backend<'_, ComputeUserInfo> { } } -#[async_trait::async_trait] -impl BackendIpAllowlist for Backend<'_, ()> { - async fn get_allowed_ips( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> auth::Result> { - let auth_data = match self { - Self::ControlPlane(api, ()) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - }; - - auth_data - .map(|ips| ips.as_ref().clone()) - .map_err(|e| e.into()) - } -} - #[async_trait::async_trait] impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { async fn wake_compute( diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 9a0b954341bb..4d919f374a2d 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -12,13 +12,15 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, info}; -use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo}; +use crate::auth::backend::ComputeUserInfo; use crate::auth::{check_peer_addr_is_in_list, AuthError}; use crate::config::ComputeConfig; use crate::context::RequestContext; +use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; +use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -133,6 +135,9 @@ pub(crate) enum CancelError { #[error("IP is not allowed")] IpNotAllowed, + #[error("VPC endpoint id is not allowed to connect")] + VpcEndpointIdNotAllowed, + #[error("Authentication backend error")] AuthError(#[from] AuthError), @@ -152,8 +157,9 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed => crate::error::ErrorKind::User, - CancelError::NotFound => crate::error::ErrorKind::User, + CancelError::IpNotAllowed + | CancelError::VpcEndpointIdNotAllowed + | CancelError::NotFound => crate::error::ErrorKind::User, CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, CancelError::InternalError => crate::error::ErrorKind::Service, } @@ -265,11 +271,12 @@ impl CancellationHandler { /// Will fetch IP allowlist internally. /// /// return Result primarily for tests - pub(crate) async fn cancel_session( + pub(crate) async fn cancel_session( &self, key: CancelKeyData, ctx: RequestContext, - check_allowed: bool, + check_ip_allowed: bool, + check_vpc_allowed: bool, auth_backend: &T, ) -> Result<(), CancelError> { let subnet_key = match ctx.peer_addr() { @@ -304,11 +311,11 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_allowed { + if check_ip_allowed { let ip_allowlist = auth_backend .get_allowed_ips(&ctx, &cancel_closure.user_info) .await - .map_err(CancelError::AuthError)?; + .map_err(|e| CancelError::AuthError(e.into()))?; if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { // log it here since cancel_session could be spawned in a task @@ -320,6 +327,40 @@ impl CancellationHandler { } } + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + let access_blocks = auth_backend + .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + .await + .map_err(|e| CancelError::AuthError(e.into()))?; + + if check_vpc_allowed { + if access_blocks.vpc_access_blocked { + return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), + Some(ConnectionInfoExtra::Aws { vpce_id }) => { + // Convert the vcpe_id to a string + String::from_utf8(vpce_id.to_vec()).unwrap_or_default() + } + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let allowed_vpc_endpoint_ids = auth_backend + .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) + .await + .map_err(|e| CancelError::AuthError(e.into()))?; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !allowed_vpc_endpoint_ids.is_empty() + && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) + { + return Err(CancelError::VpcEndpointIdNotAllowed); + } + } else if access_blocks.public_access_blocked { + return Err(CancelError::VpcEndpointIdNotAllowed); + } + Metrics::get() .proxy .cancellation_requests_total diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 78bfb6deacc3..c4548a7ddd95 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -182,7 +182,8 @@ pub(crate) async fn handle_client( cancel_key_data, ctx, config.authentication_config.ip_allowlist_check_enabled, - backend, + config.authentication_config.is_vpc_acccess_proxy, + backend.get_api(), ) .await .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index ab173bd0d052..8a407c811971 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -283,7 +283,8 @@ pub(crate) async fn handle_client( cancel_key_data, ctx, config.authentication_config.ip_allowlist_check_enabled, - auth_backend, + config.authentication_config.is_vpc_acccess_proxy, + auth_backend.get_api(), ) .await .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 84e3c580a9f4..1a7024588aa1 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::intern::{ProjectIdInt, RoleNameInt, AccountIdInt}; +use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -265,7 +265,7 @@ fn invalidate_cache(cache: Arc, msg: Notification) { .invalidate_role_secret_for_project( password_update.project_id, password_update.role_name, - ), + ), Notification::UnknownTopic => unreachable!(), } } From f4a331d24d9500170a9263ed00fbd34f1f32e0a4 Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Fri, 31 Jan 2025 21:04:37 +0200 Subject: [PATCH 16/16] Fix ip allowlist fetch in auth --- proxy/src/auth/backend/mod.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 892bc74c2134..7ef096207aed 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -283,14 +283,17 @@ async fn auth_quirks( }; debug!("fetching authentication info and allowlists"); - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; // check allowed list - if config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } + let allowed_ips = if config.ip_allowlist_check_enabled { + let allowed_ips = api.get_allowed_ips(ctx, &info).await?; + if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { + return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); + } + allowed_ips + } else { + Cached::new_uncached(Arc::new(vec![])) + }; // check if a VPC endpoint ID is coming in and if yes, if it's allowed let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;