diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index dcbd183d7cef..fc66f286bb34 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -94,6 +94,11 @@ void DOHFrontend::rotateTicketsKey(time_t now) return d_tlsContext.rotateTicketsKey(now); } +void DOHFrontend::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) +{ + return d_tlsContext.setTicketsKeyAddedHook(hook); +} + void DOHFrontend::loadTicketsKeys(const std::string& keyFile) { return d_tlsContext.loadTicketsKeys(keyFile); diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh index 0dc714df23a3..82ef70f83b36 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.hh +++ b/pdns/dnsdistdist/dnsdist-doh-common.hh @@ -162,6 +162,10 @@ struct DOHFrontend { } + virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */) + { + } + virtual void loadTicketsKeys(const std::string& /* keyFile */) { } @@ -185,6 +189,7 @@ struct DOHFrontend virtual void setup(); virtual void reloadCertificates(); + virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook); virtual void rotateTicketsKey(time_t now); virtual void loadTicketsKeys(const std::string& keyFile); virtual void handleTicketsKeyRotation(); diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index c526a93cc2d6..bb5edcd22474 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -3011,6 +3011,13 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); + + luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { + if (frontend != nullptr) { + frontend->setTicketsKeyAddedHook(hook); + } + }); + luaCtx.registerFunction::*)(const LuaArray>&)>("setResponsesMap", [](const std::shared_ptr& frontend, const LuaArray>& map) { if (frontend != nullptr) { auto newMap = std::make_shared>>(); @@ -3208,6 +3215,12 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); + luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { + if (frontend != nullptr) { + frontend->setTicketsKeyAddedHook(hook); + } + }); + luaCtx.registerFunction::*)(const std::string&)>("loadTicketsKeys", [](std::shared_ptr& ctx, const std::string& file) { if (ctx != nullptr) { ctx->loadTicketsKeys(file); @@ -3221,6 +3234,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) return frontend->d_addr.toStringWithPort(); }); + luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { + if (frontend == nullptr) { + return; + } + auto ctx = frontend->getContext(); + if (ctx) { + ctx->setTicketsKeyAddedHook(hook); + } + }); + luaCtx.registerFunction::*)()>("rotateTicketsKey", [](std::shared_ptr& frontend) { if (frontend == nullptr) { return; diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index 7473624276e6..db66c1eb8267 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -2322,6 +2322,17 @@ DOHFrontend Replace the current TLS tickets key by a new random one. + .. method:: DOHFrontend:setTicketsKeyAddedHook(callback) + + .. versionadded:: 1.9.0 + + Set a Lua function that will be called everytime a new tickets key is added. The function receives: + + * the key content as a string + * the keylen as an integer + + See :doc:`../advanced/tls-sessions-management` for more information. + .. method:: DOHFrontend:setResponsesMap(rules) Set a list of HTTP response rules allowing to intercept HTTP queries very early, before the DNS payload has been processed, and send custom responses including error pages, redirects and static content. @@ -2464,6 +2475,17 @@ TLSContext Replace the current TLS tickets key by a new random one. + .. method:: TLSContext:setTicketsKeyAddedHook(callback) + + .. versionadded:: 1.9.0 + + Set a Lua function that will be called everytime a new tickets key is added. The function receives: + + * the key content as a string + * the keylen as an integer + + See :doc:`../advanced/tls-sessions-management` for more information. + TLSFrontend ~~~~~~~~~~~ @@ -2505,6 +2527,17 @@ TLSFrontend Replace the current TLS tickets key by a new random one. + .. method:: TLSFrontend:setTicketsKeyAddedHook(callback) + + .. versionadded:: 1.9.0 + + Set a Lua function that will be called everytime a new tickets key is added. The function receives: + + * the key content as a string + * the keylen as an integer + + See :doc:`../advanced/tls-sessions-management` for more information. + EDNS on Self-generated answers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pdns/libssl.cc b/pdns/libssl.cc index 3f657326c432..f72edfaca270 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -631,6 +631,16 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default; void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr&& newKey) { d_ticketKeys.write_lock()->push_front(std::move(newKey)); + if (d_ticketsKeyAddedHook) { + auto key = d_ticketKeys.read_lock()->front(); + auto keyContent = key->content(); + d_ticketsKeyAddedHook(keyContent.c_str(), keyContent.size()); + } +} + +void OpenSSLTLSTicketKeysRing::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) +{ + d_ticketsKeyAddedHook = hook; } std::shared_ptr OpenSSLTLSTicketKeysRing::getEncryptionKey() @@ -737,6 +747,17 @@ bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_N return (memcmp(d_name, name, sizeof(d_name)) == 0); } +std::string OpenSSLTLSTicketKey::content() const +{ + std::string result{}; + result.reserve(TLS_TICKETS_KEY_NAME_SIZE + TLS_TICKETS_CIPHER_KEY_SIZE + TLS_TICKETS_MAC_KEY_SIZE); + result.append(reinterpret_cast(d_name), TLS_TICKETS_KEY_NAME_SIZE); + result.append(reinterpret_cast(d_cipherKey), TLS_TICKETS_CIPHER_KEY_SIZE); + result.append(reinterpret_cast(d_hmacKey), TLS_TICKETS_MAC_KEY_SIZE); + + return result; +} + #if OPENSSL_VERSION_MAJOR >= 3 static const std::string sha256KeyName{"sha256"}; #endif diff --git a/pdns/libssl.hh b/pdns/libssl.hh index 8dd7ff373bf6..d0ed6a96bca2 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -100,6 +100,7 @@ public: #if OPENSSL_VERSION_MAJOR >= 3 int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const; bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const; + std::string content() const; #else int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const; bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const; @@ -111,6 +112,8 @@ private: unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE]; }; +using dnsdist_tickets_key_added_hook = std::function; + class OpenSSLTLSTicketKeysRing { public: @@ -121,10 +124,11 @@ public: size_t getKeysCount(); void loadTicketsKeys(const std::string& keyFile); void rotateTicketsKey(time_t now); + void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook); private: void addKey(std::shared_ptr&& newKey); - + dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook; SharedLockGuarded > > d_ticketKeys; }; diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index cf82471ba84d..87391ba2ab18 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -813,6 +813,11 @@ class OpenSSLTLSIOCtx: public TLSCtx } } + void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override + { + d_feContext->d_ticketKeys.setTicketsKeyAddedHook(hook); + } + void loadTicketsKeys(const std::string& keyFile) final { d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); @@ -987,6 +992,14 @@ class GnuTLSTicketsKey throw; } } + std::string content() const + { + std::string result{}; + if (d_key.data != nullptr && d_key.size > 0) { + result.append(reinterpret_cast(d_key.data), d_key.size); + } + return result; + } ~GnuTLSTicketsKey() { @@ -1730,6 +1743,11 @@ class GnuTLSIOCtx: public TLSCtx return connection; } + void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override + { + d_ticketsKeyAddedHook = hook; + } + void rotateTicketsKey(time_t now) override { if (!d_enableTickets) { @@ -1745,6 +1763,12 @@ class GnuTLSIOCtx: public TLSCtx if (d_ticketsKeyRotationDelay > 0) { d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; } + + if (d_ticketsKeyAddedHook) { + auto ticketsKey = *(d_ticketsKey.read_lock()); + auto content = ticketsKey->content(); + d_ticketsKeyAddedHook(content.c_str(), content.size()); + } } void loadTicketsKeys(const std::string& file) final @@ -1792,6 +1816,7 @@ class GnuTLSIOCtx: public TLSCtx SharedLockGuarded> d_ticketsKey{nullptr}; bool d_enableTickets{true}; bool d_validateCerts{true}; + dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook; }; #endif /* HAVE_GNUTLS */ diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 058d10443b71..c592701eedf9 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -81,6 +81,10 @@ public: { throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file"); } + virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */) + { + throw std::runtime_error("This TLS backend does not have the capability to setup a hook for added tickets keys"); + } void handleTicketsKeyRotation(time_t now) { @@ -152,6 +156,13 @@ public: } } + void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) + { + if (d_ctx != nullptr) { + d_ctx->setTicketsKeyAddedHook(hook); + } + } + void loadTicketsKeys(const std::string& file) { if (d_ctx != nullptr) { diff --git a/regression-tests.dnsdist/test_TLS.py b/regression-tests.dnsdist/test_TLS.py index 9803ed550f96..6138c3e6363e 100644 --- a/regression-tests.dnsdist/test_TLS.py +++ b/regression-tests.dnsdist/test_TLS.py @@ -516,3 +516,40 @@ def setUpClass(cls): cls.startResponders() cls.startDNSDist() cls.setUpSockets() + +class TestTLSTicketsKeyAddedCallback(DNSDistTest): + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _tlsServerPort = pickAvailablePort() + _numberOfKeys = 5 + + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%s") + + newServer{address="127.0.0.1:%s"} + addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" }) + + callbackCalled = 0 + function keyAddedCallback(key, keyLen) + callbackCalled = keyLen + end + + """ + + def testLuaThreadCounter(self): + """ + LuaThread: Test the lua newThread interface + """ + self.sendConsoleCommand('getTLSFrontend(0):setTicketsKeyAddedHook(keyAddedCallback)'); + called = self.sendConsoleCommand('callbackCalled') + self.assertEqual(int(called), 0) + self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()") + called = self.sendConsoleCommand('callbackCalled') + self.assertGreater(int(called), 0)