Skip to content

Commit

Permalink
Merge pull request #686 from hatoo/preload-tls-config
Browse files Browse the repository at this point in the history
Preload tls config
  • Loading branch information
hatoo authored Feb 4, 2025
2 parents 8ee33b8 + 87d0929 commit 365d8ab
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 92 deletions.
76 changes: 6 additions & 70 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,16 @@ pub struct Client {
pub timeout: Option<std::time::Duration>,
pub redirect_limit: usize,
pub disable_keepalive: bool,
pub insecure: bool,
pub proxy_url: Option<Url>,
pub aws_config: Option<AwsSignatureConfig>,
#[cfg(unix)]
pub unix_socket: Option<std::path::PathBuf>,
#[cfg(feature = "vsock")]
pub vsock_addr: Option<tokio_vsock::VsockAddr>,
#[cfg(feature = "rustls")]
pub root_cert_store: Arc<rustls::RootCertStore>,
pub rustls_configs: crate::tls_config::RuslsConfigs,
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
pub native_tls_connectors: crate::tls_config::NativeTlsConnectors,
}

struct ClientStateHttp1 {
Expand Down Expand Up @@ -456,18 +457,7 @@ impl Client {
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut connector_builder = native_tls::TlsConnector::builder();
if self.insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

if is_http2 {
connector_builder.request_alpns(&["h2"]);
}

let connector = tokio_native_tls::TlsConnector::from(connector_builder.build()?);
let connector = self.native_tls_connectors.connector(is_http2);
let stream = connector
.connect(url.host_str().ok_or(ClientError::HostNotFound)?, stream)
.await?;
Expand All @@ -485,18 +475,8 @@ impl Client {
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(self.root_cert_store.clone())
.with_no_client_auth();
if self.insecure {
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAnyServerCert));
}
if is_http2 {
config.alpn_protocols = vec![b"h2".to_vec()];
}
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
let connector =
tokio_rustls::TlsConnector::from(self.rustls_configs.config(is_http2).clone());
let domain = rustls_pki_types::ServerName::try_from(
url.host_str().ok_or(ClientError::HostNotFound)?,
)?;
Expand Down Expand Up @@ -849,50 +829,6 @@ impl Client {
}
}

/// A server certificate verifier that accepts any certificate.
#[cfg(feature = "rustls")]
#[derive(Debug)]
struct AcceptAnyServerCert;

#[cfg(feature = "rustls")]
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::CryptoProvider::get_default()
.unwrap()
.signature_verification_algorithms
.supported_schemes()
}
}

/// Check error and decide whether to cancel the connection
fn is_cancel_error(res: &Result<RequestResult, ClientError>) -> bool {
matches!(res, Err(ClientError::Deadline)) || is_too_many_open_files(res)
Expand Down
14 changes: 3 additions & 11 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,16 @@ mod test_db {
timeout: None,
redirect_limit: 0,
disable_keepalive: false,
insecure: false,
proxy_url: None,
aws_config: None,
#[cfg(unix)]
unix_socket: None,
#[cfg(feature = "vsock")]
vsock_addr: None,
#[cfg(feature = "rustls")]
// Cache rustls_native_certs::load_native_certs() because it's expensive.
root_cert_store: {
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
std::sync::Arc::new(root_cert_store)
},
rustls_configs: crate::tls_config::RuslsConfigs::new(false),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false),
};
let result = store(&client, ":memory:", start, &test_vec);
assert_eq!(result.unwrap(), 2);
Expand Down
15 changes: 4 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod pcg64si;
mod printer;
mod result_data;
mod timescale;
mod tls_config;
mod url_generator;

#[cfg(not(target_env = "msvc"))]
Expand Down Expand Up @@ -535,23 +536,15 @@ async fn run() -> anyhow::Result<()> {
timeout: opts.timeout.map(|d| d.into()),
redirect_limit: opts.redirect,
disable_keepalive: opts.disable_keepalive,
insecure: opts.insecure,
proxy_url: opts.proxy,
#[cfg(unix)]
unix_socket: opts.unix_socket,
#[cfg(feature = "vsock")]
vsock_addr: opts.vsock_addr.map(|v| v.0),
#[cfg(feature = "rustls")]
// Cache rustls_native_certs::load_native_certs() because it's expensive.
root_cert_store: {
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
std::sync::Arc::new(root_cert_store)
},
rustls_configs: tls_config::RuslsConfigs::new(opts.insecure),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: tls_config::NativeTlsConnectors::new(opts.insecure),
});

if !opts.no_pre_lookup {
Expand Down
129 changes: 129 additions & 0 deletions src/tls_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#[cfg(feature = "rustls")]
pub struct RuslsConfigs {
no_alpn: std::sync::Arc<rustls::ClientConfig>,
alpn_h2: std::sync::Arc<rustls::ClientConfig>,
}

#[cfg(feature = "rustls")]
impl RuslsConfigs {
pub fn new(insecure: bool) -> Self {
use std::sync::Arc;

let mut root_cert_store = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store.clone())
.with_no_client_auth();
if insecure {
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAnyServerCert));
}

let mut no_alpn = config.clone();
no_alpn.alpn_protocols = vec![];
let mut alpn_h2 = config;
alpn_h2.alpn_protocols = vec![b"h2".to_vec()];
Self {
no_alpn: Arc::new(no_alpn),
alpn_h2: Arc::new(alpn_h2),
}
}

pub fn config(&self, is_http2: bool) -> &std::sync::Arc<rustls::ClientConfig> {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
pub struct NativeTlsConnectors {
pub no_alpn: tokio_native_tls::TlsConnector,
pub alpn_h2: tokio_native_tls::TlsConnector,
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
impl NativeTlsConnectors {
pub fn new(insecure: bool) -> Self {
let new = |is_http2: bool| {
let mut connector_builder = native_tls::TlsConnector::builder();
if insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

if is_http2 {
connector_builder.request_alpns(&["h2"]);
}

connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
};

Self {
no_alpn: new(false),
alpn_h2: new(true),
}
}

pub fn connector(&self, is_http2: bool) -> &tokio_native_tls::TlsConnector {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

/// A server certificate verifier that accepts any certificate.
#[cfg(feature = "rustls")]
#[derive(Debug)]
pub struct AcceptAnyServerCert;

#[cfg(feature = "rustls")]
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::CryptoProvider::get_default()
.unwrap()
.signature_verification_algorithms
.supported_schemes()
}
}

0 comments on commit 365d8ab

Please sign in to comment.