Skip to content

Commit

Permalink
Add to the cancellation VPC logic
Browse files Browse the repository at this point in the history
  • Loading branch information
awarus committed Jan 29, 2025
1 parent 81ee691 commit 83c48ec
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 66 deletions.
23 changes: 6 additions & 17 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};

use super::{ComputeCredentialKeys, ControlPlaneApi};
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
use super::ComputeCredentialKeys;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
Expand Down Expand Up @@ -84,26 +84,15 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}

#[async_trait]
impl BackendIpAllowlist for ConsoleRedirectBackend {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>> {
self.api
.get_allowed_ips(ctx, user_info)
.await
.map(|ips| ips.as_ref().clone())
.map_err(|e| e.into())
}
}

impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self {
Self { console_uri, api }
}

pub(crate) fn get_api(&self) -> &cplane_proxy_v1::NeonControlPlaneClient {
&self.api
}

pub(crate) async fn authenticate(
&self,
ctx: &RequestContext,
Expand Down
38 changes: 9 additions & 29 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ impl<T> Backend<'_, T> {
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
}
}

pub(crate) fn get_api(&self) -> &ControlPlaneClient {
match self {
Self::ControlPlane(api, _) => api,
Self::Local(_) => panic!("Local backend has no API"),
}
}
}

impl<'a, T> Backend<'a, T> {
Expand Down Expand Up @@ -249,15 +256,6 @@ impl AuthenticationConfig {
}
}

#[async_trait::async_trait]
pub(crate) trait BackendIpAllowlist {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>>;
}

/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
Expand Down Expand Up @@ -300,8 +298,8 @@ async fn auth_quirks(
if access_blocks.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let extra = ctx.extra();
let incoming_vpc_endpoint_id = match extra {

let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingEndpointName),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
Expand Down Expand Up @@ -502,24 +500,6 @@ impl Backend<'_, ComputeUserInfo> {
}
}

#[async_trait::async_trait]
impl BackendIpAllowlist for Backend<'_, ()> {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>> {
let auth_data = match self {
Self::ControlPlane(api, ()) => api.get_allowed_ips(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
};

auth_data
.map(|ips| ips.as_ref().clone())
.map_err(|e| e.into())
}
}

#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
Expand Down
55 changes: 48 additions & 7 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::{debug, info};

use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{check_peer_addr_is_in_list, AuthError};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::CancelChannelSizeGuard;
use crate::metrics::{CancellationRequest, Metrics, RedisMsgKind};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
Expand Down Expand Up @@ -135,6 +137,9 @@ pub(crate) enum CancelError {
#[error("IP is not allowed")]
IpNotAllowed,

#[error("VPC endpoint id is not allowed to connect")]
VpcEndpointIdNotAllowed,

#[error("Authentication backend error")]
AuthError(#[from] AuthError),

Expand All @@ -154,8 +159,9 @@ impl ReportableError for CancelError {
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
CancelError::NotFound => crate::error::ErrorKind::User,
CancelError::IpNotAllowed
| CancelError::VpcEndpointIdNotAllowed
| CancelError::NotFound => crate::error::ErrorKind::User,
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
CancelError::InternalError => crate::error::ErrorKind::Service,
}
Expand Down Expand Up @@ -267,11 +273,12 @@ impl CancellationHandler {
/// Will fetch IP allowlist internally.
///
/// return Result primarily for tests
pub(crate) async fn cancel_session<T: BackendIpAllowlist>(
pub(crate) async fn cancel_session<T: ControlPlaneApi>(
&self,
key: CancelKeyData,
ctx: RequestContext,
check_allowed: bool,
check_ip_allowed: bool,
check_vpc_allowed: bool,
auth_backend: &T,
) -> Result<(), CancelError> {
let subnet_key = match ctx.peer_addr() {
Expand Down Expand Up @@ -306,11 +313,11 @@ impl CancellationHandler {
return Err(CancelError::NotFound);
};

if check_allowed {
if check_ip_allowed {
let ip_allowlist = auth_backend
.get_allowed_ips(&ctx, &cancel_closure.user_info)
.await
.map_err(CancelError::AuthError)?;
.map_err(|e| CancelError::AuthError(e.into()))?;

if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
// log it here since cancel_session could be spawned in a task
Expand All @@ -322,6 +329,40 @@ impl CancellationHandler {
}
}

// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = auth_backend
.get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;

if check_vpc_allowed {
if access_blocks.vpc_access_blocked {
return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
}

let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
}
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};

let allowed_vpc_endpoint_ids = auth_backend
.get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(CancelError::VpcEndpointIdNotAllowed);
}
} else if access_blocks.public_access_blocked {
return Err(CancelError::VpcEndpointIdNotAllowed);
}

Metrics::get()
.proxy
.cancellation_requests_total
Expand Down
3 changes: 2 additions & 1 deletion proxy/src/console_redirect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
backend,
config.authentication_config.is_vpc_acccess_proxy,
backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
Expand Down
3 changes: 2 additions & 1 deletion proxy/src/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
auth_backend,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
Expand Down
11 changes: 2 additions & 9 deletions proxy/src/redis/kv_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ pub struct RedisKVClient {
}

impl RedisKVClient {
pub fn new(
client: ConnectionWithCredentialsProvider,
info: &'static [RateBucketInfo],
) -> Self {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
Expand Down Expand Up @@ -42,11 +39,7 @@ impl RedisKVClient {
return Err(anyhow::anyhow!("Rate limit exceeded"));
}

match self
.client
.hset(&key, &field, &value)
.await
{
match self.client.hset(&key, &field, &value).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to set a key-value pair: {e}");
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/redis/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use uuid::Uuid;

use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::intern::{ProjectIdInt, RoleNameInt, AccountIdInt};
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};

const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
Expand Down Expand Up @@ -265,7 +265,7 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
.invalidate_role_secret_for_project(
password_update.project_id,
password_update.role_name,
),
),
Notification::UnknownTopic => unreachable!(),
}
}
Expand Down

0 comments on commit 83c48ec

Please sign in to comment.