From 300deebf41d2da96701fe29c0faa8025b7efa120 Mon Sep 17 00:00:00 2001 From: Lang Hames <lhames@gmail.com> Date: Tue, 17 Dec 2024 17:27:52 +1100 Subject: [PATCH] [ORC] Make LazyReexportsManager implement ResourceManager. This ensures that the reexports mappings are cleared when the resource tracker associated with each mapping is removed. --- llvm/include/llvm/ExecutionEngine/Orc/Core.h | 5 ++ .../llvm/ExecutionEngine/Orc/LazyReexports.h | 9 ++- .../lib/ExecutionEngine/Orc/LazyReexports.cpp | 58 +++++++++++---- .../Orc/LazyCallThroughAndReexportsTest.cpp | 74 +++++++++++++++++++ 4 files changed, 129 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h index 7f75d799cab6a..2788932ca4bcb 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h @@ -200,6 +200,11 @@ class SymbolLookupSet { SymbolLookupSet() = default; + SymbolLookupSet(std::initializer_list<value_type> Elems) { + for (auto &E : Elems) + Symbols.push_back(std::move(E)); + } + explicit SymbolLookupSet( SymbolStringPtr Name, SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) { diff --git a/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h b/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h index 0dcf646b12dd8..cc9c664d0e7c0 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/LazyReexports.h @@ -173,7 +173,7 @@ lazyReexports(LazyCallThroughManager &LCTManager, LCTManager, RSManager, SourceJD, std::move(CallableAliases), SrcJDLoc); } -class LazyReexportsManager { +class LazyReexportsManager : public ResourceManager { friend std::unique_ptr<MaterializationUnit> lazyReexports(LazyReexportsManager &, SymbolAliasMap); @@ -194,6 +194,10 @@ class LazyReexportsManager { LazyReexportsManager(LazyReexportsManager &&) = delete; LazyReexportsManager &operator=(LazyReexportsManager &&) = delete; + Error handleRemoveResources(JITDylib &JD, ResourceKey K) override; + void handleTransferResources(JITDylib &JD, ResourceKey DstK, + ResourceKey SrcK) override; + private: struct CallThroughInfo { SymbolStringPtr Name; @@ -222,10 +226,11 @@ class LazyReexportsManager { Expected<std::vector<ExecutorSymbolDef>> ReentryPoints); void resolve(ResolveSendResultFn SendResult, ExecutorAddr ReentryStubAddr); + ExecutionSession &ES; EmitTrampolinesFn EmitTrampolines; RedirectableSymbolManager &RSMgr; - std::mutex M; + DenseMap<ResourceKey, ExecutorAddr> KeyToReentryAddr; DenseMap<ExecutorAddr, CallThroughInfo> CallThroughs; }; diff --git a/llvm/lib/ExecutionEngine/Orc/LazyReexports.cpp b/llvm/lib/ExecutionEngine/Orc/LazyReexports.cpp index 7a7e5d13ce03f..6e1e3746bfa24 100644 --- a/llvm/lib/ExecutionEngine/Orc/LazyReexports.cpp +++ b/llvm/lib/ExecutionEngine/Orc/LazyReexports.cpp @@ -292,16 +292,39 @@ LazyReexportsManager::Create(EmitTrampolinesFn EmitTrampolines, return std::move(LRM); } +Error LazyReexportsManager::handleRemoveResources(JITDylib &JD, ResourceKey K) { + JD.getExecutionSession().runSessionLocked([&]() { + auto I = KeyToReentryAddr.find(K); + if (I != KeyToReentryAddr.end()) { + auto ReentryAddr = I->second; + CallThroughs.erase(ReentryAddr); + KeyToReentryAddr.erase(I); + } + }); + return Error::success(); +} + +void LazyReexportsManager::handleTransferResources(JITDylib &JD, + ResourceKey DstK, + ResourceKey SrcK) { + auto I = KeyToReentryAddr.find(SrcK); + if (I != KeyToReentryAddr.end()) { + auto ReentryAddr = I->second; + KeyToReentryAddr.erase(I); + KeyToReentryAddr[DstK] = ReentryAddr; + } +} + LazyReexportsManager::LazyReexportsManager(EmitTrampolinesFn EmitTrampolines, RedirectableSymbolManager &RSMgr, JITDylib &PlatformJD, Error &Err) - : EmitTrampolines(std::move(EmitTrampolines)), RSMgr(RSMgr) { + : ES(PlatformJD.getExecutionSession()), + EmitTrampolines(std::move(EmitTrampolines)), RSMgr(RSMgr) { using namespace shared; ErrorAsOutParameter _(&Err); - auto &ES = PlatformJD.getExecutionSession(); ExecutionSession::JITDispatchHandlerAssociationMap WFs; WFs[ES.intern("__orc_rt_resolve_tag")] = @@ -345,15 +368,22 @@ void LazyReexportsManager::emitRedirectableSymbols( // Bind entry points to names. SymbolMap Redirs; - { - std::lock_guard<std::mutex> Lock(M); - size_t I = 0; - for (auto &[Name, AI] : Reexports) { - const auto &ReentryPoint = (*ReentryPoints)[I++]; - Redirs[Name] = ReentryPoint; - CallThroughs[ReentryPoint.getAddress()] = {Name, AI.Aliasee, - &MR->getTargetJITDylib()}; - } + size_t I = 0; + for (auto &[Name, AI] : Reexports) + Redirs[Name] = (*ReentryPoints)[I++]; + + I = 0; + if (auto Err = MR->withResourceKeyDo([&](ResourceKey K) { + for (auto &[Name, AI] : Reexports) { + const auto &ReentryPoint = (*ReentryPoints)[I++]; + CallThroughs[ReentryPoint.getAddress()] = {Name, AI.Aliasee, + &MR->getTargetJITDylib()}; + KeyToReentryAddr[K] = ReentryPoint.getAddress(); + } + })) { + MR->getExecutionSession().reportError(std::move(Err)); + MR->failMaterialization(); + return; } RSMgr.emitRedirectableSymbols(std::move(MR), std::move(Redirs)); @@ -364,9 +394,7 @@ void LazyReexportsManager::resolve(ResolveSendResultFn SendResult, CallThroughInfo LandingInfo; - { - std::lock_guard<std::mutex> Lock(M); - + ES.runSessionLocked([&]() { auto I = CallThroughs.find(ReentryStubAddr); if (I == CallThroughs.end()) return SendResult(make_error<StringError>( @@ -374,7 +402,7 @@ void LazyReexportsManager::resolve(ResolveSendResultFn SendResult, " not registered", inconvertibleErrorCode())); LandingInfo = I->second; - } + }); SymbolInstance LandingSym(LandingInfo.JD, std::move(LandingInfo.BodyName)); LandingSym.lookupAsync([this, JD = std::move(LandingInfo.JD), diff --git a/llvm/unittests/ExecutionEngine/Orc/LazyCallThroughAndReexportsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/LazyCallThroughAndReexportsTest.cpp index 7f367cfd58739..6bb244c20e8e2 100644 --- a/llvm/unittests/ExecutionEngine/Orc/LazyCallThroughAndReexportsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/LazyCallThroughAndReexportsTest.cpp @@ -1,6 +1,11 @@ #include "OrcTestCommon.h" +#include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h" +#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h" +#include "llvm/ExecutionEngine/Orc/JITLinkReentryTrampolines.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" #include "llvm/ExecutionEngine/Orc/LazyReexports.h" +#include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" using namespace llvm; @@ -70,3 +75,72 @@ TEST_F(LazyReexportsTest, BasicLocalCallThroughManagerOperation) { << "CallThrough should have generated exactly one 'NotifyResolved' call"; EXPECT_EQ(Result, 42) << "Failed to call through to target"; } + +static void *noReentry(void *) { abort(); } + +TEST(JITLinkLazyReexportsTest, Basics) { + OrcNativeTarget::initialize(); + + auto J = LLJITBuilder().create(); + if (!J) { + dbgs() << toString(J.takeError()) << "\n"; + // consumeError(J.takeError()); + GTEST_SKIP(); + } + if (!isa<ObjectLinkingLayer>((*J)->getObjLinkingLayer())) + GTEST_SKIP(); + + auto &OLL = cast<ObjectLinkingLayer>((*J)->getObjLinkingLayer()); + + auto RSMgr = JITLinkRedirectableSymbolManager::Create(OLL); + if (!RSMgr) { + dbgs() << "Boom for RSMgr\n"; + consumeError(RSMgr.takeError()); + GTEST_SKIP(); + } + + auto &ES = (*J)->getExecutionSession(); + + auto &JD = ES.createBareJITDylib("JD"); + cantFail(JD.define(absoluteSymbols( + {{ES.intern("__orc_rt_reentry"), + {ExecutorAddr::fromPtr(&noReentry), + JITSymbolFlags::Exported | JITSymbolFlags::Callable}}}))); + + auto LRMgr = createJITLinkLazyReexportsManager(OLL, **RSMgr, JD); + if (!LRMgr) { + dbgs() << "Boom for LRMgr\n"; + consumeError(LRMgr.takeError()); + GTEST_SKIP(); + } + + auto Foo = ES.intern("foo"); + auto Bar = ES.intern("bar"); + + auto RT = JD.createResourceTracker(); + cantFail(JD.define( + lazyReexports( + **LRMgr, + {{Foo, {Bar, JITSymbolFlags::Exported | JITSymbolFlags::Callable}}}), + RT)); + + // Check flags after adding Foo -> Bar lazy reexport. + auto SF = cantFail( + ES.lookupFlags(LookupKind::Static, makeJITDylibSearchOrder(&JD), + {{Foo, SymbolLookupFlags::WeaklyReferencedSymbol}})); + EXPECT_EQ(SF.size(), 1U); + EXPECT_TRUE(SF.count(Foo)); + EXPECT_EQ(SF[Foo], JITSymbolFlags::Exported | JITSymbolFlags::Callable); + + // Remove reexport without running it. + if (auto Err = RT->remove()) { + EXPECT_THAT_ERROR(std::move(Err), Succeeded()); + return; + } + + // Check flags after adding Foo -> Bar lazy reexport. + SF = cantFail( + ES.lookupFlags(LookupKind::Static, makeJITDylibSearchOrder(&JD), + {{Foo, SymbolLookupFlags::WeaklyReferencedSymbol}})); + EXPECT_EQ(SF.size(), 0U); +}