Skip to content

Commit

Permalink
added wrapper type and todos (#348)
Browse files Browse the repository at this point in the history
* added wrapper type and todos

* modified code to support ipv4 or ipv6 connection

* added function for iplookup and made enums in separate file

* added the required enum and made the required changes

* added functionality to respect ip version preferences

* fixed the logic for calcuating the delay

* addressed the parameters issue
  • Loading branch information
AnkurRathore authored Feb 8, 2025
1 parent bb55530 commit caf3a0d
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 39 deletions.
1 change: 1 addition & 0 deletions rama-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod address;
pub mod asn;
pub mod client;
pub mod forwarded;
pub mod mode;
pub mod stream;
pub mod user;

Expand Down
39 changes: 39 additions & 0 deletions rama-net/src/mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
/// Enum representing the IP modes that can be used by the DNS resolver.
pub enum DnsResolveIpMode {
#[default]
Dual,
SingleIpV4,
SingleIpV6,
DualPreferIpV4,
}

impl DnsResolveIpMode {
/// checks if IPv4 is supported in current mode
pub fn ipv4_supported(&self) -> bool {
matches!(
self,
DnsResolveIpMode::Dual
| DnsResolveIpMode::SingleIpV4
| DnsResolveIpMode::DualPreferIpV4
)
}

/// checks if IPv6 is supported in current mode
pub fn ipv6_supported(&self) -> bool {
matches!(
self,
DnsResolveIpMode::Dual
| DnsResolveIpMode::SingleIpV6
| DnsResolveIpMode::DualPreferIpV4
)
}
}
///Mode for establishing a connection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConnectIpMode {
#[default]
Dual,
Ipv4,
Ipv6,
}
119 changes: 80 additions & 39 deletions rama-tcp/src/client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rama_core::{
};
use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
use rama_net::address::{Authority, Domain, Host};
use rama_net::mode::{ConnectIpMode, DnsResolveIpMode};
use std::{
future::Future,
net::{IpAddr, SocketAddr},
Expand Down Expand Up @@ -133,6 +134,18 @@ where
let domain = match host {
Host::Name(domain) => domain,
Host::Address(ip) => {
let mode = ConnectIpMode::Dual;
//check if IP Version is allowed
match (ip, mode) {
(IpAddr::V4(_), ConnectIpMode::Ipv6) => {
return Err(OpaqueError::from_display("IPv4 address is not allowed"));
}
(IpAddr::V6(_), ConnectIpMode::Ipv4) => {
return Err(OpaqueError::from_display("IPv6 address is not allowed"));
}
_ => (),
}

// if the authority is already defined as an IP address, we can directly connect to it
let addr = (ip, port).into();
let stream = connector
Expand All @@ -150,8 +163,10 @@ where
ctx,
domain.clone(),
port,
dns_overwrite.deref().clone(),
DnsResolveIpMode::Dual, // Use the mode from the overwrite
dns_overwrite.deref().clone(), // Convert DnsOverwrite to a DnsResolver
connector.clone(),
ConnectIpMode::Dual,
)
.await
{
Expand All @@ -160,63 +175,71 @@ where
}
}


//... otherwise we'll try to establish a connection,
// with dual-stack parallel connections...

tcp_connect_inner(ctx, domain, port, dns, connector).await
tcp_connect_inner(
ctx,
domain,
port,
DnsResolveIpMode::Dual,
dns,
connector,
ConnectIpMode::Dual,
)
.await
}

async fn tcp_connect_inner<State, Dns, Connector>(
ctx: &Context<State>,
domain: Domain,
port: u16,
dns_mode: DnsResolveIpMode,
dns: Dns,
connector: Connector,
connect_mode: ConnectIpMode,
) -> Result<(TcpStream, SocketAddr), OpaqueError>
where
State: Clone + Send + Sync + 'static,
Dns: DnsResolver<Error: Into<BoxError>> + Clone,
Connector: TcpStreamConnector<Error: Into<BoxError> + Send + 'static> + Clone,
{
let (tx, mut rx) = channel(1);

let connected = Arc::new(AtomicBool::new(false));
let sem = Arc::new(Semaphore::new(3));

// IPv6
let ipv6_tx = tx.clone();
let ipv6_domain = domain.clone();
let ipv6_connected = connected.clone();
let ipv6_sem = sem.clone();
ctx.spawn(tcp_connect_inner_branch(
dns.clone(),
connector.clone(),
IpKind::Ipv6,
ipv6_domain,
port,
ipv6_tx,
ipv6_connected,
ipv6_sem,
));

// IPv4
let ipv4_tx = tx;
let ipv4_domain = domain.clone();
let ipv4_connected = connected.clone();
let ipv4_sem = sem;
ctx.spawn(tcp_connect_inner_branch(
dns,
connector,
IpKind::Ipv4,
ipv4_domain,
port,
ipv4_tx,
ipv4_connected,
ipv4_sem,
));
if dns_mode.ipv4_supported() {
ctx.spawn(tcp_connect_inner_branch(
dns_mode,
dns.clone(),
connect_mode,
connector.clone(),
IpKind::Ipv4,
domain.clone(),
port,
tx.clone(),
connected.clone(),
sem.clone(),
));
}


if dns_mode.ipv6_supported() {
ctx.spawn(tcp_connect_inner_branch(
dns_mode,
dns.clone(),
connect_mode,
connector.clone(),
IpKind::Ipv6,
domain.clone(),
port,
tx.clone(),
connected.clone(),
sem.clone(),
));
}

// wait for the first connection to succeed,
// ignore the rest of the connections (sorry, but not sorry)
if let Some((stream, addr)) = rx.recv().await {
connected.store(true, Ordering::Release);
return Ok((stream, addr));
Expand All @@ -227,6 +250,7 @@ where
)))
}


#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum IpKind {
Ipv4,
Expand All @@ -235,7 +259,9 @@ enum IpKind {

#[allow(clippy::too_many_arguments)]
async fn tcp_connect_inner_branch<Dns, Connector>(
dns_mode: DnsResolveIpMode,
dns: Dns,
connect_mode: ConnectIpMode,
connector: Connector,
ip_kind: IpKind,
domain: Domain,
Expand Down Expand Up @@ -266,18 +292,33 @@ async fn tcp_connect_inner_branch<Dns, Connector>(
},
};

let (ipv4_delay_scalar, ipv6_delay_scalar) = match dns_mode {
DnsResolveIpMode::DualPreferIpV4 | DnsResolveIpMode::SingleIpV4 => (15 * 2, 21 * 2),
_ => (21 * 2, 15 * 2),
};
for (index, ip) in ip_it.enumerate() {
let addr = (ip, port).into();

let sem = sem.clone();
let sem = match (ip.is_ipv4(), connect_mode) {
(true, ConnectIpMode::Ipv6) => {
tracing::trace!("[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv4 address is not allowed)");
continue;
}
(false, ConnectIpMode::Ipv4) => {
tracing::trace!("[{ip_kind:?}] #{index}: abort connect loop to {addr} (IPv6 address is not allowed)");
continue;
}
_ => sem.clone(),
};

let tx = tx.clone();
let connected = connected.clone();

// back off retries exponentially
if index > 0 {
let delay = match ip_kind {
IpKind::Ipv4 => Duration::from_micros((21 * 2 * index) as u64),
IpKind::Ipv6 => Duration::from_micros((15 * 2 * index) as u64),
IpKind::Ipv4 => Duration::from_micros((ipv4_delay_scalar * index) as u64),
IpKind::Ipv6 => Duration::from_micros((ipv6_delay_scalar * index) as u64),
};
tokio::time::sleep(delay).await;
}
Expand Down

0 comments on commit caf3a0d

Please sign in to comment.