Skip to content

Commit

Permalink
Avoid boxing in DefaultConnector chain
Browse files Browse the repository at this point in the history
  • Loading branch information
algesten committed Jan 3, 2025
1 parent 58af416 commit 2c2af87
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 167 deletions.
20 changes: 15 additions & 5 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::middleware::MiddlewareNext;
use crate::pool::ConnectionPool;
use crate::resolver::{DefaultResolver, Resolver};
use crate::send_body::AsSendBody;
use crate::transport::{Connector, DefaultConnector};
use crate::transport::{boxed_connector, Connector, DefaultConnector, Transport};
use crate::{Error, RequestBuilder, SendBody};
use crate::{WithBody, WithoutBody};

Expand Down Expand Up @@ -96,18 +96,18 @@ pub struct Agent {
impl Agent {
/// Creates an agent with defaults.
pub fn new_with_defaults() -> Self {
Self::with_parts(
Self::with_parts_inner(
Config::default(),
DefaultConnector::default(),
Box::new(DefaultConnector::default()),
DefaultResolver::default(),
)
}

/// Creates an agent with config.
pub fn new_with_config(config: Config) -> Self {
Self::with_parts(
Self::with_parts_inner(
config,
DefaultConnector::default(),
Box::new(DefaultConnector::default()),
DefaultResolver::default(),
)
}
Expand All @@ -123,6 +123,16 @@ impl Agent {
///
/// _This is low level API that isn't for regular use of ureq._
pub fn with_parts(config: Config, connector: impl Connector, resolver: impl Resolver) -> Self {
let boxed = boxed_connector(connector);
Self::with_parts_inner(config, boxed, resolver)
}

/// Inner helper to avoid additional boxing of the [`DefaultConnector`].
fn with_parts_inner(
config: Config,
connector: Box<dyn Connector<(), Out = Box<dyn Transport>>>,
resolver: impl Resolver,
) -> Self {
let pool = Arc::new(ConnectionPool::new(connector, &config));

Agent {
Expand Down
6 changes: 3 additions & 3 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ use crate::util::DebugAuthority;
use crate::Error;

pub(crate) struct ConnectionPool {
connector: Box<dyn Connector>,
connector: Box<dyn Connector<Out = Box<dyn Transport>>>,
pool: Arc<Mutex<Pool>>,
}

impl ConnectionPool {
pub fn new(connector: impl Connector, config: &Config) -> Self {
pub fn new(connector: Box<dyn Connector<Out = Box<dyn Transport>>>, config: &Config) -> Self {
ConnectionPool {
connector: Box::new(connector),
connector,
pool: Arc::new(Mutex::new(Pool::new(config))),
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,14 @@ impl Proxy {
/// wrapped in TLS.
pub struct ConnectProxyConnector;

impl Connector for ConnectProxyConnector {
impl<In: Transport> Connector<In> for ConnectProxyConnector {
type Out = In;

fn connect(
&self,
details: &ConnectionDetails,
chained: Option<Box<dyn Transport>>,
) -> Result<Option<Box<dyn Transport>>, Error> {
chained: Option<In>,
) -> Result<Option<Self::Out>, Error> {
let Some(transport) = chained else {
return Ok(None);
};
Expand Down
20 changes: 11 additions & 9 deletions src/tls/native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ pub struct NativeTlsConnector {
connector: OnceCell<Arc<TlsConnector>>,
}

impl Connector for NativeTlsConnector {
impl<In: Transport> Connector<In> for NativeTlsConnector {
type Out = Either<In, NativeTlsTransport>;

fn connect(
&self,
details: &ConnectionDetails,
chained: Option<Box<dyn Transport>>,
) -> Result<Option<Box<dyn Transport>>, Error> {
chained: Option<In>,
) -> Result<Option<Self::Out>, Error> {
let Some(transport) = chained else {
panic!("NativeTlsConnector requires a chained transport");
};
Expand All @@ -35,12 +37,12 @@ impl Connector for NativeTlsConnector {
// already, otherwise use chained transport as is.
if !details.needs_tls() || transport.is_tls() {
trace!("Skip");
return Ok(Some(transport));
return Ok(Some(Either::A(transport)));
}

if details.config.tls_config().provider != TlsProvider::NativeTls {
debug!("Skip because config is not set to Native TLS");
return Ok(Some(transport));
return Ok(Some(Either::A(transport)));
}

trace!("Try wrap TLS");
Expand All @@ -67,19 +69,19 @@ impl Connector for NativeTlsConnector {
.host()
.to_string();

let adapter = TransportAdapter::new(transport);
let adapter = TransportAdapter::new(transport.boxed());
let stream = LazyStream::Unstarted(Some((connector, domain, adapter)));

let buffers = LazyBuffers::new(
details.config.input_buffer_size(),
details.config.output_buffer_size(),
);

let transport = Box::new(NativeTlsTransport { buffers, stream });
let transport = NativeTlsTransport { buffers, stream };

debug!("Wrapped TLS");

Ok(Some(transport))
Ok(Some(Either::B(transport)))
}
}

Expand Down Expand Up @@ -167,7 +169,7 @@ fn pemify(der: &[u8], label: &'static str) -> Result<String, Error> {
Ok(pem)
}

struct NativeTlsTransport {
pub struct NativeTlsTransport {
buffers: LazyBuffers,
stream: LazyStream,
}
Expand Down
22 changes: 12 additions & 10 deletions src/tls/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustls_pki_types::{PrivateSec1KeyDer, ServerName};
use crate::tls::cert::KeyKind;
use crate::tls::{RootCerts, TlsProvider};
use crate::transport::{Buffers, ConnectionDetails, Connector, LazyBuffers};
use crate::transport::{NextTimeout, Transport, TransportAdapter};
use crate::transport::{Either, NextTimeout, Transport, TransportAdapter};
use crate::Error;

use super::TlsConfig;
Expand All @@ -25,12 +25,14 @@ pub struct RustlsConnector {
config: OnceCell<Arc<ClientConfig>>,
}

impl Connector for RustlsConnector {
impl<In: Transport> Connector<In> for RustlsConnector {
type Out = Either<In, RustlsTransport>;

fn connect(
&self,
details: &ConnectionDetails,
chained: Option<Box<dyn Transport>>,
) -> Result<Option<Box<dyn Transport>>, Error> {
chained: Option<In>,
) -> Result<Option<Self::Out>, Error> {
let Some(transport) = chained else {
panic!("RustlConnector requires a chained transport");
};
Expand All @@ -39,12 +41,12 @@ impl Connector for RustlsConnector {
// already, otherwise use chained transport as is.
if !details.needs_tls() || transport.is_tls() {
trace!("Skip");
return Ok(Some(transport));
return Ok(Some(Either::A(transport)));
}

if details.config.tls_config().provider != TlsProvider::Rustls {
debug!("Skip because config is not set to Rustls");
return Ok(Some(transport));
return Ok(Some(Either::A(transport)));
}

trace!("Try wrap in TLS");
Expand All @@ -71,19 +73,19 @@ impl Connector for RustlsConnector {
let conn = ClientConnection::new(config, name)?;
let stream = StreamOwned {
conn,
sock: TransportAdapter::new(transport),
sock: TransportAdapter::new(transport.boxed()),
};

let buffers = LazyBuffers::new(
details.config.input_buffer_size(),
details.config.output_buffer_size(),
);

let transport = Box::new(RustlsTransport { buffers, stream });
let transport = RustlsTransport { buffers, stream };

debug!("Wrapped TLS");

Ok(Some(transport))
Ok(Some(Either::B(transport)))
}
}

Expand Down Expand Up @@ -166,7 +168,7 @@ fn build_config(tls_config: &TlsConfig) -> Arc<ClientConfig> {
Arc::new(config)
}

struct RustlsTransport {
pub struct RustlsTransport {
buffers: LazyBuffers,
stream: StreamOwned<ClientConnection, TransportAdapter>,
}
Expand Down
151 changes: 117 additions & 34 deletions src/unversioned/transport/chain.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,133 @@
use crate::Error;
use std::fmt;
use std::marker::PhantomData;

use super::{ConnectionDetails, Connector, Transport};
use super::{Connector, Transport};

/// Helper for a chain of connectors.
/// Two chained connectors called one after another.
///
/// Each step of the chain, can decide whether to:
///
/// * _Keep_ previous [`Transport`]
/// * _Wrap_ previous [`Transport`]
/// * _Ignore_ previous [`Transport`] in favor of some other connection.
///
/// For each new connection, the chain will be called one by one and the previously chained
/// transport will be provided to the next as an argument in [`Connector::connect()`].
///
/// The chain is always looped fully. There is no early return.
/// Created by calling [`Connector::chain`] on the first connector.
pub struct ChainedConnector<In, First, Second>(First, Second, PhantomData<In>);

impl<In, First, Second> Connector<In> for ChainedConnector<In, First, Second>
where
In: Transport,
First: Connector<In>,
Second: Connector<First::Out>,
{
type Out = Second::Out;

fn connect(
&self,
details: &super::ConnectionDetails,
chained: Option<In>,
) -> Result<Option<Self::Out>, crate::Error> {
let f_out = self.0.connect(details, chained)?;
self.1.connect(details, f_out)
}
}

impl<In, First, Second> ChainedConnector<In, First, Second> {
pub(crate) fn new(first: First, second: Second) -> Self {
ChainedConnector(first, second, PhantomData)
}
}

impl<In, First, Second> fmt::Debug for ChainedConnector<In, First, Second>
where
In: Transport,
First: Connector<In>,
Second: Connector<First::Out>,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ChainedConnector")
.field(&self.0)
.field(&self.1)
.finish()
}
}

/// A selection between two transports.
#[derive(Debug)]
pub struct ChainedConnector {
chain: Vec<Box<dyn Connector>>,
pub enum Either<A, B> {
/// The first transport.
A(A),
/// The second transport.
B(B),
}

impl ChainedConnector {
/// Creates a new chain of connectors.
///
/// For each connection, the chain will be called one by one and the previously chained
/// transport will be provided to the next as an argument in [`Connector::connect()`].
///
/// The chain is always looped fully. There is no early return.
pub fn new(chain: impl IntoIterator<Item = Box<dyn Connector>>) -> Self {
Self {
chain: chain.into_iter().collect(),
impl<A: Transport, B: Transport> Transport for Either<A, B> {
fn buffers(&mut self) -> &mut dyn super::Buffers {
match self {
Either::A(a) => a.buffers(),
Either::B(b) => b.buffers(),
}
}

fn transmit_output(
&mut self,
amount: usize,
timeout: super::NextTimeout,
) -> Result<(), crate::Error> {
match self {
Either::A(a) => a.transmit_output(amount, timeout),
Either::B(b) => b.transmit_output(amount, timeout),
}
}

fn await_input(&mut self, timeout: super::NextTimeout) -> Result<bool, crate::Error> {
match self {
Either::A(a) => a.await_input(timeout),
Either::B(b) => b.await_input(timeout),
}
}

fn is_open(&mut self) -> bool {
match self {
Either::A(a) => a.is_open(),
Either::B(b) => b.is_open(),
}
}

fn is_tls(&self) -> bool {
match self {
Either::A(a) => a.is_tls(),
Either::B(b) => b.is_tls(),
}
}
}

impl Connector for ChainedConnector {
// Connector is implemented for () to start a chain of connectors.
//
// The `Out` transport is supposedly `()`, but this is never instantiated.
impl Connector<()> for () {
type Out = ();

fn connect(
&self,
details: &ConnectionDetails,
chained: Option<Box<dyn Transport>>,
) -> Result<Option<Box<dyn Transport>>, Error> {
let mut conn = chained;
_: &super::ConnectionDetails,
_: Option<()>,
) -> Result<Option<Self::Out>, crate::Error> {
Ok(None)
}
}

for connector in &self.chain {
conn = connector.connect(details, conn)?;
}
// () is a valid Transport for type reasons.
//
// It should never be instantiated as an actual transport.
impl Transport for () {
fn buffers(&mut self) -> &mut dyn super::Buffers {
panic!("Unit transport is not valid")
}

fn transmit_output(&mut self, _: usize, _: super::NextTimeout) -> Result<(), crate::Error> {
panic!("Unit transport is not valid")
}

fn await_input(&mut self, _: super::NextTimeout) -> Result<bool, crate::Error> {
panic!("Unit transport is not valid")
}

Ok(conn)
fn is_open(&mut self) -> bool {
panic!("Unit transport is not valid")
}
}
Loading

0 comments on commit 2c2af87

Please sign in to comment.