From 195f89a7c59d2ce92790debad191b97f410def0c Mon Sep 17 00:00:00 2001 From: Charles-Henri Bruyand Date: Thu, 27 Jun 2024 15:02:39 +0200 Subject: [PATCH] dnsdist: move the setTicketsKeyAddedHook to a unique callback for every tls context --- pdns/dnsdistdist/dnsdist-doh-common.cc | 5 --- pdns/dnsdistdist/dnsdist-doh-common.hh | 5 --- pdns/dnsdistdist/dnsdist-lua-hooks.cc | 17 +++++++++ pdns/dnsdistdist/dnsdist-lua-hooks.hh | 3 ++ pdns/dnsdistdist/dnsdist-lua.cc | 23 ----------- pdns/dnsdistdist/docs/reference/config.rst | 44 ++++++---------------- pdns/libssl.cc | 10 ++--- pdns/libssl.hh | 4 -- pdns/tcpiohandler.cc | 36 +++++++----------- pdns/tcpiohandler.hh | 29 ++++++++------ regression-tests.dnsdist/test_TLS.py | 2 +- 11 files changed, 65 insertions(+), 113 deletions(-) diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index fc66f286bb34..dcbd183d7cef 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -94,11 +94,6 @@ void DOHFrontend::rotateTicketsKey(time_t now) return d_tlsContext.rotateTicketsKey(now); } -void DOHFrontend::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) -{ - return d_tlsContext.setTicketsKeyAddedHook(hook); -} - void DOHFrontend::loadTicketsKeys(const std::string& keyFile) { return d_tlsContext.loadTicketsKeys(keyFile); diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh index 82ef70f83b36..0dc714df23a3 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.hh +++ b/pdns/dnsdistdist/dnsdist-doh-common.hh @@ -162,10 +162,6 @@ struct DOHFrontend { } - virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */) - { - } - virtual void loadTicketsKeys(const std::string& /* keyFile */) { } @@ -189,7 +185,6 @@ struct DOHFrontend virtual void setup(); virtual void reloadCertificates(); - virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook); virtual void rotateTicketsKey(time_t now); virtual void loadTicketsKeys(const std::string& keyFile); virtual void handleTicketsKeyRotation(); diff --git a/pdns/dnsdistdist/dnsdist-lua-hooks.cc b/pdns/dnsdistdist/dnsdist-lua-hooks.cc index c5ccb48915c1..621e73451205 100644 --- a/pdns/dnsdistdist/dnsdist-lua-hooks.cc +++ b/pdns/dnsdistdist/dnsdist-lua-hooks.cc @@ -2,6 +2,7 @@ #include "dnsdist-lua-hooks.hh" #include "dnsdist-lua.hh" #include "lock.hh" +#include "tcpiohandler.hh" namespace dnsdist::lua::hooks { @@ -26,12 +27,28 @@ void clearMaintenanceHooks() s_maintenanceHooks.lock()->clear(); } +void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook) +{ + TLSCtx::setTicketsKeyAddedHook([hook](const std::string& key) { + try { + hook(key.c_str(), key.size()); + } + catch (const std::exception& exp) { + warnlog("Error calling the Lua hook after new tickets key has been added", exp.what()); + } + }); +} + void setupLuaHooks(LuaContext& luaCtx) { luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) { setLuaSideEffect(); addMaintenanceCallback(luaCtx, callback); }); + luaCtx.writeFunction("setTicketsKeyAddedHook", [&luaCtx](const TicketsKeyAddedHook& hook) { + setLuaSideEffect(); + setTicketsKeyAddedHook(luaCtx, hook); + }); } } diff --git a/pdns/dnsdistdist/dnsdist-lua-hooks.hh b/pdns/dnsdistdist/dnsdist-lua-hooks.hh index 11a9084883ee..8cbb7c903ae9 100644 --- a/pdns/dnsdistdist/dnsdist-lua-hooks.hh +++ b/pdns/dnsdistdist/dnsdist-lua-hooks.hh @@ -28,8 +28,11 @@ class LuaContext; namespace dnsdist::lua::hooks { using MaintenanceCallback = std::function; +using TicketsKeyAddedHook = std::function; + void runMaintenanceHooks(const LuaContext& context); void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback); +void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook); void clearMaintenanceHooks(); void setupLuaHooks(LuaContext& luaCtx); } diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index bb5edcd22474..c526a93cc2d6 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -3011,13 +3011,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); - - luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { - if (frontend != nullptr) { - frontend->setTicketsKeyAddedHook(hook); - } - }); - luaCtx.registerFunction::*)(const LuaArray>&)>("setResponsesMap", [](const std::shared_ptr& frontend, const LuaArray>& map) { if (frontend != nullptr) { auto newMap = std::make_shared>>(); @@ -3215,12 +3208,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); - luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { - if (frontend != nullptr) { - frontend->setTicketsKeyAddedHook(hook); - } - }); - luaCtx.registerFunction::*)(const std::string&)>("loadTicketsKeys", [](std::shared_ptr& ctx, const std::string& file) { if (ctx != nullptr) { ctx->loadTicketsKeys(file); @@ -3234,16 +3221,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) return frontend->d_addr.toStringWithPort(); }); - luaCtx.registerFunction::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr& frontend, const dnsdist_tickets_key_added_hook& hook) { - if (frontend == nullptr) { - return; - } - auto ctx = frontend->getContext(); - if (ctx) { - ctx->setTicketsKeyAddedHook(hook); - } - }); - luaCtx.registerFunction::*)()>("rotateTicketsKey", [](std::shared_ptr& frontend) { if (frontend == nullptr) { return; diff --git a/pdns/dnsdistdist/docs/reference/config.rst b/pdns/dnsdistdist/docs/reference/config.rst index db66c1eb8267..80ad8ab46546 100644 --- a/pdns/dnsdistdist/docs/reference/config.rst +++ b/pdns/dnsdistdist/docs/reference/config.rst @@ -2173,6 +2173,17 @@ Other functions Code is supplied as a string, not as a function object. Note that this function does nothing in 'client' or 'config-check' modes. +.. function:: setTicketsKeyAddedHook(callback) + + .. versionadded:: 1.9.0 + + Set a Lua function that will be called everytime a new tickets key is added. The function receives: + + * the key content as a string + * the keylen as an integer + + See :doc:`../advanced/tls-sessions-management` for more information. + .. function:: submitToMainThread(cmd, dict) .. versionadded:: 1.8.0 @@ -2322,17 +2333,6 @@ DOHFrontend Replace the current TLS tickets key by a new random one. - .. method:: DOHFrontend:setTicketsKeyAddedHook(callback) - - .. versionadded:: 1.9.0 - - Set a Lua function that will be called everytime a new tickets key is added. The function receives: - - * the key content as a string - * the keylen as an integer - - See :doc:`../advanced/tls-sessions-management` for more information. - .. method:: DOHFrontend:setResponsesMap(rules) Set a list of HTTP response rules allowing to intercept HTTP queries very early, before the DNS payload has been processed, and send custom responses including error pages, redirects and static content. @@ -2475,17 +2475,6 @@ TLSContext Replace the current TLS tickets key by a new random one. - .. method:: TLSContext:setTicketsKeyAddedHook(callback) - - .. versionadded:: 1.9.0 - - Set a Lua function that will be called everytime a new tickets key is added. The function receives: - - * the key content as a string - * the keylen as an integer - - See :doc:`../advanced/tls-sessions-management` for more information. - TLSFrontend ~~~~~~~~~~~ @@ -2527,17 +2516,6 @@ TLSFrontend Replace the current TLS tickets key by a new random one. - .. method:: TLSFrontend:setTicketsKeyAddedHook(callback) - - .. versionadded:: 1.9.0 - - Set a Lua function that will be called everytime a new tickets key is added. The function receives: - - * the key content as a string - * the keylen as an integer - - See :doc:`../advanced/tls-sessions-management` for more information. - EDNS on Self-generated answers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pdns/libssl.cc b/pdns/libssl.cc index f72edfaca270..cd9ad076fe94 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -42,6 +42,7 @@ #undef CERT #include "misc.hh" +#include "tcpiohandler.hh" #if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL) /* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */ @@ -631,18 +632,13 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default; void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr&& newKey) { d_ticketKeys.write_lock()->push_front(std::move(newKey)); - if (d_ticketsKeyAddedHook) { + if (TLSCtx::hasTicketsKeyAddedHook()) { auto key = d_ticketKeys.read_lock()->front(); auto keyContent = key->content(); - d_ticketsKeyAddedHook(keyContent.c_str(), keyContent.size()); + TLSCtx::getTicketsKeyAddedHook()(keyContent); } } -void OpenSSLTLSTicketKeysRing::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) -{ - d_ticketsKeyAddedHook = hook; -} - std::shared_ptr OpenSSLTLSTicketKeysRing::getEncryptionKey() { return d_ticketKeys.read_lock()->front(); diff --git a/pdns/libssl.hh b/pdns/libssl.hh index d0ed6a96bca2..c1ed2067407e 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -112,8 +112,6 @@ private: unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE]; }; -using dnsdist_tickets_key_added_hook = std::function; - class OpenSSLTLSTicketKeysRing { public: @@ -124,11 +122,9 @@ public: size_t getKeysCount(); void loadTicketsKeys(const std::string& keyFile); void rotateTicketsKey(time_t now); - void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook); private: void addKey(std::shared_ptr&& newKey); - dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook; SharedLockGuarded > > d_ticketKeys; }; diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 87391ba2ab18..1fb91ef5e0ab 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -22,6 +22,7 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false; #include "libssl.hh" +dnsdist_tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr}; class OpenSSLFrontendContext { @@ -813,11 +814,6 @@ class OpenSSLTLSIOCtx: public TLSCtx } } - void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override - { - d_feContext->d_ticketKeys.setTicketsKeyAddedHook(hook); - } - void loadTicketsKeys(const std::string& keyFile) final { d_feContext->d_ticketKeys.loadTicketsKeys(keyFile); @@ -1743,19 +1739,12 @@ class GnuTLSIOCtx: public TLSCtx return connection; } - void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override - { - d_ticketsKeyAddedHook = hook; - } - - void rotateTicketsKey(time_t now) override + void addTicketsKey(time_t now, std::shared_ptr&& newKey) { if (!d_enableTickets) { return; } - auto newKey = std::make_shared(); - { *(d_ticketsKey.write_lock()) = std::move(newKey); } @@ -1764,13 +1753,21 @@ class GnuTLSIOCtx: public TLSCtx d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay; } - if (d_ticketsKeyAddedHook) { + if (TLSCtx::hasTicketsKeyAddedHook()) { auto ticketsKey = *(d_ticketsKey.read_lock()); auto content = ticketsKey->content(); - d_ticketsKeyAddedHook(content.c_str(), content.size()); + TLSCtx::getTicketsKeyAddedHook()(content); } } + void rotateTicketsKey(time_t now) override + { + if (!d_enableTickets) { + return; + } + auto newKey = std::make_shared(); + addTicketsKey(now, std::move(newKey)); + } void loadTicketsKeys(const std::string& file) final { if (!d_enableTickets) { @@ -1778,13 +1775,7 @@ class GnuTLSIOCtx: public TLSCtx } auto newKey = std::make_shared(file); - { - *(d_ticketsKey.write_lock()) = std::move(newKey); - } - - if (d_ticketsKeyRotationDelay > 0) { - d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; - } + addTicketsKey(time(nullptr), std::move(newKey)); } size_t getTicketsKeysCount() override @@ -1816,7 +1807,6 @@ class GnuTLSIOCtx: public TLSCtx SharedLockGuarded> d_ticketsKey{nullptr}; bool d_enableTickets{true}; bool d_validateCerts{true}; - dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook; }; #endif /* HAVE_GNUTLS */ diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index c592701eedf9..59817beefe5d 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -66,6 +66,8 @@ protected: bool d_resumedFromInactiveTicketKey{false}; }; +using dnsdist_tickets_key_added_hook = std::function; + class TLSCtx { public: @@ -81,11 +83,6 @@ public: { throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file"); } - virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */) - { - throw std::runtime_error("This TLS backend does not have the capability to setup a hook for added tickets keys"); - } - void handleTicketsKeyRotation(time_t now) { if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) { @@ -128,10 +125,25 @@ public: return false; } + static void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) + { + TLSCtx::s_ticketsKeyAddedHook = hook; + } + static const dnsdist_tickets_key_added_hook& getTicketsKeyAddedHook() + { + return TLSCtx::s_ticketsKeyAddedHook; + } + static bool hasTicketsKeyAddedHook() + { + return TLSCtx::s_ticketsKeyAddedHook != nullptr; + } protected: std::atomic_flag d_rotatingTicketsKey; std::atomic d_ticketsKeyNextRotation{0}; time_t d_ticketsKeyRotationDelay{0}; + +private: + static dnsdist_tickets_key_added_hook s_ticketsKeyAddedHook; }; class TLSFrontend @@ -156,13 +168,6 @@ public: } } - void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) - { - if (d_ctx != nullptr) { - d_ctx->setTicketsKeyAddedHook(hook); - } - } - void loadTicketsKeys(const std::string& file) { if (d_ctx != nullptr) { diff --git a/regression-tests.dnsdist/test_TLS.py b/regression-tests.dnsdist/test_TLS.py index 6138c3e6363e..27c2de52fe19 100644 --- a/regression-tests.dnsdist/test_TLS.py +++ b/regression-tests.dnsdist/test_TLS.py @@ -547,7 +547,7 @@ def testLuaThreadCounter(self): """ LuaThread: Test the lua newThread interface """ - self.sendConsoleCommand('getTLSFrontend(0):setTicketsKeyAddedHook(keyAddedCallback)'); + self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)'); called = self.sendConsoleCommand('callbackCalled') self.assertEqual(int(called), 0) self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")