From c104a38960fb6886915cb8311861123b0142f59a Mon Sep 17 00:00:00 2001 From: Max Dymond Date: Thu, 4 Mar 2021 14:07:10 +0000 Subject: [PATCH 1/2] CPP-928 Ensure server name information flows through from contact point configuration Previously during Address name resolution, the server name information for a given Address was lost. This fix ensures that the server name information flows through during the name resolution process for a given Address. --- include/cassandra.h | 6 ++++++ src/address.cpp | 5 +++-- src/address.hpp | 2 +- src/client_insights.cpp | 5 +++-- src/cluster_config.cpp | 3 ++- src/cluster_metadata_resolver.cpp | 4 ++-- src/resolver.hpp | 11 +++++++---- src/socket_connector.cpp | 4 ++-- src/ssl/ssl_openssl_impl.cpp | 4 ++-- tests/src/unit/tests/test_address.cpp | 6 ++++-- tests/src/unit/tests/test_resolver.cpp | 27 +++++++++++++------------- tests/src/unit/tests/test_socket.cpp | 3 ++- 12 files changed, 48 insertions(+), 32 deletions(-) diff --git a/include/cassandra.h b/include/cassandra.h index 63341bb3c..1b50508cc 100644 --- a/include/cassandra.h +++ b/include/cassandra.h @@ -4602,6 +4602,12 @@ cass_ssl_add_trusted_cert_n(CassSsl* ssl, * common name or one of its subject alternative names. This implies the * certificate is also present. Hostname resolution must also be enabled. * + * Notes: + * - CASS_SSL_VERIFY_PEER_IDENTITY and CASS_SSL_VERIFY_PEER_IDENTITY_DNS are + * mutually exclusive options. + * - The certificate Common Name is only checked against the IP address or + * hostname if there are no Subject Alternative Names in the certificate. + * * Default: CASS_SSL_VERIFY_PEER_CERT * * @public @memberof CassSsl diff --git a/src/address.cpp b/src/address.cpp index 752bd9ae3..c2bb5bf15 100644 --- a/src/address.cpp +++ b/src/address.cpp @@ -75,8 +75,9 @@ Address::Address(const uint8_t* address, uint8_t address_length, int port) } } -Address::Address(const struct sockaddr* addr) - : family_(UNRESOLVED) +Address::Address(const struct sockaddr* addr, const String& server_name) + : server_name_(server_name) + , family_(UNRESOLVED) , port_(0) { if (addr->sa_family == AF_INET) { const struct sockaddr_in* addr_in = reinterpret_cast(addr); diff --git a/src/address.hpp b/src/address.hpp index c74c6253a..87c508e0a 100644 --- a/src/address.hpp +++ b/src/address.hpp @@ -73,7 +73,7 @@ class Address : public Allocated { Address(const Address& other, const String& server_name); Address(const String& hostname_or_address, int port, const String& server_name = String()); Address(const uint8_t* address, uint8_t address_length, int port); - Address(const struct sockaddr* addr); + Address(const struct sockaddr* addr, const String& server_name); bool equals(const Address& other, bool with_port = true) const; diff --git a/src/client_insights.cpp b/src/client_insights.cpp index 351c9c656..d4ad3cdbc 100644 --- a/src/client_insights.cpp +++ b/src/client_insights.cpp @@ -635,7 +635,7 @@ class StartupMessageHandler : public RefCounted { new MultiResolver(bind_callback(&StartupMessageHandler::on_resolve, this))); } resolver->resolve(connection_->loop(), contact_point.hostname_or_address(), port, - config_.resolve_timeout_ms()); + config_.resolve_timeout_ms(), contact_point.server_name()); } } @@ -668,7 +668,8 @@ class StartupMessageHandler : public RefCounted { Address::SocketStorage name; int namelen = sizeof(name); if (uv_tcp_getsockname(tcp, name.addr(), &namelen) == 0) { - Address address(name.addr()); + // Pass a blank server name as this is a temporary address. + Address address(name.addr(), String()); if (address.is_valid_and_resolved()) { return address.to_string(); } diff --git a/src/cluster_config.cpp b/src/cluster_config.cpp index 9fac5b326..c80a69756 100644 --- a/src/cluster_config.cpp +++ b/src/cluster_config.cpp @@ -131,7 +131,8 @@ CassError cass_cluster_set_contact_points_n(CassCluster* cluster, const char* co explode(String(contact_points, contact_points_length), exploded); for (Vector::const_iterator it = exploded.begin(), end = exploded.end(); it != end; ++it) { - cluster->config().contact_points().push_back(Address(*it, -1)); + // Treat the address string as the server name. + cluster->config().contact_points().push_back(Address(*it, -1, *it)); } } return CASS_OK; diff --git a/src/cluster_metadata_resolver.cpp b/src/cluster_metadata_resolver.cpp index 78ef0c70d..bfe10e382 100644 --- a/src/cluster_metadata_resolver.cpp +++ b/src/cluster_metadata_resolver.cpp @@ -39,13 +39,13 @@ class DefaultClusterMetadataResolver : public ClusterMetadataResolver { int port = it->port() <= 0 ? port_ : it->port(); if (it->is_resolved()) { - resolved_contact_points_.push_back(Address(it->hostname_or_address(), port)); + resolved_contact_points_.push_back(Address(it->hostname_or_address(), port, it->server_name())); } else { if (!resolver_) { resolver_.reset( new MultiResolver(bind_callback(&DefaultClusterMetadataResolver::on_resolve, this))); } - resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_); + resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_, it->server_name()); } } diff --git a/src/resolver.hpp b/src/resolver.hpp index 2895d81ef..c21663bf2 100644 --- a/src/resolver.hpp +++ b/src/resolver.hpp @@ -47,8 +47,9 @@ class Resolver : public RefCounted { SUCCESS }; - Resolver(const String& hostname, int port, const Callback& callback) + Resolver(const String& hostname, int port, const Callback& callback, const String& server_name) : hostname_(hostname) + , server_name_(server_name) , port_(port) , status_(NEW) , callback_(callback) { @@ -139,7 +140,7 @@ class Resolver : public RefCounted { bool init_addresses(struct addrinfo* res) { bool status = false; do { - Address address(res->ai_addr); + Address address(res->ai_addr, server_name_); if (address.is_valid_and_resolved()) { addresses_.push_back(address); status = true; @@ -153,6 +154,7 @@ class Resolver : public RefCounted { uv_getaddrinfo_t req_; Timer timer_; String hostname_; + String server_name_; int port_; Status status_; int uv_status_; @@ -175,10 +177,11 @@ class MultiResolver : public RefCounted { const Resolver::Vec& resolvers() { return resolvers_; } void resolve(uv_loop_t* loop, const String& host, int port, uint64_t timeout, - struct addrinfo* hints = NULL) { + const String& server_name, struct addrinfo* hints = NULL) { inc_ref(); Resolver::Ptr resolver( - new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this))); + new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this), + server_name)); resolver->resolve(loop, timeout, hints); resolvers_.push_back(resolver); remaining_++; diff --git a/src/socket_connector.cpp b/src/socket_connector.cpp index 3d396a4df..6ce5c1d1b 100644 --- a/src/socket_connector.cpp +++ b/src/socket_connector.cpp @@ -120,11 +120,11 @@ void SocketConnector::connect(uv_loop_t* loop) { hostname_ = address_.hostname_or_address(); resolver_.reset(new Resolver(hostname_, address_.port(), - bind_callback(&SocketConnector::on_resolve, this))); + bind_callback(&SocketConnector::on_resolve, this), + address_.server_name())); resolver_->resolve(loop, settings_.resolve_timeout_ms); } else { resolved_address_ = address_; - if (settings_.hostname_resolution_enabled) { // Run hostname resolution then connect. name_resolver_.reset( new NameResolver(address_, bind_callback(&SocketConnector::on_name_resolve, this))); diff --git a/src/ssl/ssl_openssl_impl.cpp b/src/ssl/ssl_openssl_impl.cpp index 3b1124378..8a5cda6db 100644 --- a/src/ssl/ssl_openssl_impl.cpp +++ b/src/ssl/ssl_openssl_impl.cpp @@ -489,8 +489,8 @@ void OpenSslSession::verify() { return; } } else if (verify_flags_ & - CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using hostnames (including wildcards) - switch (OpenSslVerifyIdentity::match_dns(peer_cert, hostname_)) { + CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using the server name (including wildcards) + switch (OpenSslVerifyIdentity::match_dns(peer_cert, sni_server_name_)) { case OpenSslVerifyIdentity::MATCH: // Success break; diff --git a/tests/src/unit/tests/test_address.cpp b/tests/src/unit/tests/test_address.cpp index cae21eb25..ee4655523 100644 --- a/tests/src/unit/tests/test_address.cpp +++ b/tests/src/unit/tests/test_address.cpp @@ -17,9 +17,11 @@ #include #include "address.hpp" +#include "string.hpp" using datastax::internal::core::Address; using datastax::internal::core::AddressSet; +using datastax::String; TEST(AddressUnitTest, FromString) { EXPECT_TRUE(Address("127.0.0.1", 9042).is_resolved()); @@ -64,14 +66,14 @@ TEST(AddressUnitTest, CompareIPv6) { TEST(AddressUnitTest, ToSockAddrIPv4) { Address expected("127.0.0.1", 9042); Address::SocketStorage storage; - Address actual(expected.to_sockaddr(&storage)); + Address actual(expected.to_sockaddr(&storage), String()); EXPECT_EQ(expected, actual); } TEST(AddressUnitTest, ToSockAddrIPv6) { Address expected("::1", 9042); Address::SocketStorage storage; - Address actual(expected.to_sockaddr(&storage)); + Address actual(expected.to_sockaddr(&storage), String()); EXPECT_EQ(expected, actual); } diff --git a/tests/src/unit/tests/test_resolver.cpp b/tests/src/unit/tests/test_resolver.cpp index 2efbdd920..f73257d58 100644 --- a/tests/src/unit/tests/test_resolver.cpp +++ b/tests/src/unit/tests/test_resolver.cpp @@ -31,8 +31,9 @@ class ResolverUnitTest : public LoopTest { : status_(Resolver::NEW) {} Resolver::Ptr create(const String& hostname, int port = 9042) { + // Use the hostname as the TLS server name. return Resolver::Ptr( - new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this))); + new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this), hostname)); } MultiResolver::Ptr create_multi() { @@ -108,9 +109,9 @@ TEST_F(ResolverUnitTest, Cancel) { TEST_F(ResolverUnitTest, Multi) { MultiResolver::Ptr resolver(create_multi()); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); run_loop(); ASSERT_EQ(3u, resolvers().size()); for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it; @@ -130,9 +131,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) { starve_thread_pool(200); // Use shortest possible timeout for all requests - resolver->resolve(loop(), "localhost", 9042, 1); - resolver->resolve(loop(), "localhost", 9042, 1); - resolver->resolve(loop(), "localhost", 9042, 1); + resolver->resolve(loop(), "localhost", 9042, 1, "localhost"); + resolver->resolve(loop(), "localhost", 9042, 1, "localhost"); + resolver->resolve(loop(), "localhost", 9042, 1, "localhost"); run_loop(); ASSERT_EQ(3u, resolvers().size()); @@ -145,9 +146,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) { TEST_F(ResolverUnitTest, MultiInvalid) { MultiResolver::Ptr resolver(create_multi()); - resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT); + resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist1.dne"); + resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist2.dne"); + resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist3.dne"); run_loop(); ASSERT_EQ(3u, resolvers().size()); for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it; @@ -159,9 +160,9 @@ TEST_F(ResolverUnitTest, MultiInvalid) { TEST_F(ResolverUnitTest, MultiCancel) { MultiResolver::Ptr resolver(create_multi()); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); - resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); + resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost"); resolver->cancel(); run_loop(); ASSERT_EQ(3u, resolvers().size()); diff --git a/tests/src/unit/tests/test_socket.cpp b/tests/src/unit/tests/test_socket.cpp index 37bf730be..3939fae33 100644 --- a/tests/src/unit/tests/test_socket.cpp +++ b/tests/src/unit/tests/test_socket.cpp @@ -198,7 +198,8 @@ class SocketUnitTest : public LoopTest { } else { bool match = false; do { - Address address(res->ai_addr); + // Use a blank server name as it's not needed here. + Address address(res->ai_addr, String()); if (address.is_valid_and_resolved() && address == Address(DNS_IP_ADDRESS, 8888)) { match = true; break; From c130ae443dc76ce5371c6c21ec7a451a29b43293 Mon Sep 17 00:00:00 2001 From: Max Dymond Date: Thu, 4 Mar 2021 14:13:52 +0000 Subject: [PATCH 2/2] Iterate over all certificates in a trusted cert BIO, not just the first Previously the code which loaded a trusted certificate from file only assumed that there was a single certificate in that file, meaning that using a certificate bundle for certificate verification would not work. This fix allows the driver to read multiple trusted certificates out of a BIO and provision them in the trusted certificate store. --- src/ssl/ssl_openssl_impl.cpp | 41 ++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/ssl/ssl_openssl_impl.cpp b/src/ssl/ssl_openssl_impl.cpp index 8a5cda6db..3c69b90f9 100644 --- a/src/ssl/ssl_openssl_impl.cpp +++ b/src/ssl/ssl_openssl_impl.cpp @@ -228,22 +228,6 @@ static int SSL_CTX_use_certificate_chain_bio(SSL_CTX* ctx, BIO* in) { return ret; } -static X509* load_cert(const char* cert, size_t cert_size) { - BIO* bio = BIO_new_mem_buf(const_cast(cert), cert_size); - if (bio == NULL) { - return NULL; - } - - X509* x509 = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL); - if (x509 == NULL) { - ssl_log_errors("Unable to load certificate"); - } - - BIO_free_all(bio); - - return x509; -} - static EVP_PKEY* load_key(const char* key, size_t key_size, const char* password) { BIO* bio = BIO_new_mem_buf(const_cast(key), key_size); if (bio == NULL) { @@ -556,13 +540,30 @@ SslSession* OpenSslContext::create_session(const Address& address, const String& } CassError OpenSslContext::add_trusted_cert(const char* cert, size_t cert_length) { - X509* x509 = load_cert(cert, cert_length); - if (x509 == NULL) { + BIO* bio = BIO_new_mem_buf(const_cast(cert), cert_length); + if (bio == NULL) { return CASS_ERROR_SSL_INVALID_CERT; } - X509_STORE_add_cert(trusted_store_, x509); - X509_free(x509); + int num_certs = 0; + + // Iterate over the bio, reading out as many certificates as possible. + for (X509* cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL); + cert != NULL; + cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL)) + { + X509_STORE_add_cert(trusted_store_, cert); + X509_free(cert); + num_certs++; + } + + BIO_free_all(bio); + + // If no certificates were read from the bio, that is an error. + if (num_certs == 0) { + ssl_log_errors("Unable to load certificate(s)"); + return CASS_ERROR_SSL_INVALID_CERT; + } return CASS_OK; }