From a750a317b1df91de7c1eb10becdae332b7597df4 Mon Sep 17 00:00:00 2001 From: koparasy Date: Mon, 16 Dec 2024 21:38:32 -0800 Subject: [PATCH 1/5] Computes a SHA256 hash per device compilation unit and uses for hashing --- lib/JitEngineDevice.hpp | 153 +++++++++++++++++++++++++++--------- lib/JitEngineDeviceCUDA.cpp | 47 ++++++++--- lib/JitEngineDeviceCUDA.hpp | 7 +- pass/ProteusPass.cpp | 42 ++++++++++ 4 files changed, 196 insertions(+), 53 deletions(-) diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index aa253171..f02a92ba 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -13,12 +13,14 @@ #include "llvm/Linker/Linker.h" #include +#include #include #include #include #include #include #include +#include #include #include "llvm/Analysis/TargetTransformInfo.h" @@ -33,6 +35,7 @@ #include "llvm/MC/TargetRegistry.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Utils/Cloning.h" #include #include #include @@ -115,6 +118,15 @@ template class JitEngineDevice : protected JitEngine { std::unordered_map ModuleIdToFatBinary; DenseMap HandleToBinaryInfo; DenseMap KernelToHandleMap; + + SmallVector>>> + SHA256HashWithBitcodes; + + DenseMap> LinkedLLVMIRModules; + DenseMap KernelToBitcodeIndex; + /* @Brief After proteus initialization contains all kernels annotathed with + * proteus */ + DenseSet ProteusAnnotatedKernels; SmallVector GlobalLinkedModuleIds; SmallPtrSet GlobalLinkedBinaries; @@ -280,8 +292,15 @@ template class JitEngineDevice : protected JitEngine { Image); } - std::unique_ptr extractDeviceBitcode(StringRef KernelName, - void *Kernel) { + std::unique_ptr + createLinkedModule(ArrayRef> LinkedModules, + StringRef KernelName) { + TIMESCOPE(__FUNCTION__) + return static_cast(*this).createLinkedModule(LinkedModules, + KernelName); + } + + int extractDeviceBitcode(StringRef KernelName, void *Kernel) { TIMESCOPE(__FUNCTION__) return static_cast(*this).extractDeviceBitcode(KernelName, 
Kernel); } @@ -299,18 +318,33 @@ template class JitEngineDevice : protected JitEngine { std::unordered_map &VarNameToDevPtr); protected: - JitEngineDevice() {} + JitEngineDevice() { ProteusCtx = std::make_unique(); } ~JitEngineDevice() { CodeCache.printStats(); StorageCache.printStats(); + // Note: We manually clear or unique_ptr to Modules before the destructor + // releases the ProteusCtx. + // + // Explicitly clear the LinkedLLVMIRModules + LinkedLLVMIRModules.clear(); + + // Explicitly clear SHA256HashWithBitcodes + for (auto &Entry : SHA256HashWithBitcodes) + Entry.second.clear(); + SHA256HashWithBitcodes.clear(); } JitCache CodeCache; JitStorageCache StorageCache; std::string DeviceArch; std::unordered_map VarNameToDevPtr; - void linkJitModule(Module *M, LLVMContext *Ctx, StringRef KernelName, - SmallVector> &LinkedModules); + void linkJitModule(Module &M, StringRef KernelName, + ArrayRef> LinkedModules); + std::string + getCombinedModuleHash(ArrayRef> LinkedModules); + + // All modules are associated with context, to guarantee correct lifetime. + std::unique_ptr ProteusCtx; private: // This map is private and only accessible via the API. @@ -435,14 +469,28 @@ JitEngineDevice::compileAndRun( uint64_t ShmemSize, typename DeviceTraits::DeviceStream_t Stream) { TIMESCOPE("compileAndRun"); + // This was never registered, return immediately + if (!KernelToHandleMap.contains(Kernel)) + return launchKernelDirect(Kernel, GridDim, BlockDim, KernelArgs, ShmemSize, + Stream); + SmallVector RCsVec; getRuntimeConstantValues(KernelArgs, RCIndices, RCTypes, RCsVec); - uint64_t HashValue = CodeCache.hash(ModuleUniqueId, KernelName, RCsVec.data(), - NumRuntimeConstants); - typename DeviceTraits::KernelFunction_t KernelFunc = - CodeCache.lookup(HashValue); + typename DeviceTraits::KernelFunction_t KernelFunc; + + auto Index = KernelToBitcodeIndex.contains(Kernel) + ? 
KernelToBitcodeIndex[Kernel] + : extractDeviceBitcode(KernelName, Kernel); + + // I have already read the LLVM IR from the Binary. Pick the Static Hash + auto StaticHash = SHA256HashWithBitcodes[Index].first; + // TODO: This does not include the GridDims/BlockDims. We need to fix it. + uint64_t DynamicHashValue = CodeCache.hash( + StaticHash, KernelName, RCsVec.data(), NumRuntimeConstants); + KernelFunc = CodeCache.lookup(DynamicHashValue); + // We found the kernel, execute if (KernelFunc) return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, ShmemSize, Stream); @@ -450,11 +498,12 @@ JitEngineDevice::compileAndRun( // NOTE: we don't need a suffix to differentiate kernels, each specialization // will be in its own module uniquely identify by HashValue. It exists only // for debugging purposes to verify that the jitted kernel executes. - std::string Suffix = mangleSuffix(HashValue); + std::string Suffix = mangleSuffix(DynamicHashValue); std::string KernelMangled = (KernelName + Suffix).str(); if (Config.ENV_PROTEUS_USE_STORED_CACHE) { - // If there device global variables, lookup the IR and codegen object + // FIXME: The code cache is completely broken as of now. I need to revisit + // this. If there device global variables, lookup the IR and codegen object // before launching. Else, if there aren't device global variables, lookup // the object and launch. @@ -466,13 +515,13 @@ JitEngineDevice::compileAndRun( bool HasDeviceGlobals = !VarNameToDevPtr.empty(); if (auto CacheBuf = (HasDeviceGlobals - ? StorageCache.lookupBitcode(HashValue, KernelMangled) - : StorageCache.lookupObject(HashValue, KernelMangled))) { + ? 
StorageCache.lookupBitcode(DynamicHashValue, KernelMangled) + : StorageCache.lookupObject(DynamicHashValue, + KernelMangled))) { std::unique_ptr ObjBuf; if (HasDeviceGlobals) { - auto Ctx = std::make_unique(); SMDiagnostic Err; - auto M = parseIR(CacheBuf->getMemBufferRef(), Err, *Ctx); + auto M = parseIR(CacheBuf->getMemBufferRef(), Err, *ProteusCtx.get()); relinkGlobals(*M, VarNameToDevPtr); ObjBuf = codegenObject(*M, DeviceArch); } else { @@ -482,7 +531,7 @@ JitEngineDevice::compileAndRun( auto KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(HashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -490,22 +539,18 @@ JitEngineDevice::compileAndRun( } } - auto IRBuffer = extractDeviceBitcode(KernelName, Kernel); - - auto parseBitcode = [&]() -> Expected { - auto Ctx = std::make_unique(); - SMDiagnostic Err; - if (auto M = parseIR(IRBuffer->getMemBufferRef(), Err, *Ctx)) - return orc::ThreadSafeModule(std::move(M), std::move(Ctx)); - - return createSMDiagnosticError(Err); - }; - - auto SafeModule = parseBitcode(); - if (auto E = SafeModule.takeError()) - FATAL_ERROR(toString(std::move(E)).c_str()); + if (!LinkedLLVMIRModules.contains(Kernel)) { + // if we get here, we have access to the LLVM-IR of the module, but we + // have never linked everything together and internalized the symbols. + LinkedLLVMIRModules.insert( + {Kernel, + createLinkedModule(SHA256HashWithBitcodes[Index].second, KernelName)}); + } - auto *JitModule = SafeModule->getModuleUnlocked(); + // We need to clone, The JitModule will be specialized later, and we need + // the one stored under LinkedLLVMIRModules to be a generic version prior + // specialization. 
+ auto JitModule = llvm::CloneModule(*LinkedLLVMIRModules[Kernel]); specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices, RCsVec.data(), NumRuntimeConstants); @@ -517,18 +562,18 @@ JitEngineDevice::compileAndRun( SmallString<4096> ModuleBuffer; raw_svector_ostream ModuleBufferOS(ModuleBuffer); WriteBitcodeToFile(*JitModule, ModuleBufferOS); - StorageCache.storeBitcode(HashValue, ModuleBuffer); + StorageCache.storeBitcode(DynamicHashValue, ModuleBuffer); relinkGlobals(*JitModule, VarNameToDevPtr); auto ObjBuf = codegenObject(*JitModule, DeviceArch); if (Config.ENV_PROTEUS_USE_STORED_CACHE) - StorageCache.storeObject(HashValue, ObjBuf->getMemBufferRef()); + StorageCache.storeObject(DynamicHashValue, ObjBuf->getMemBufferRef()); KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(HashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -587,6 +632,10 @@ void JitEngineDevice::registerFunction(void *Handle, void *Kernel, "Expected kernel inserted only once in the map"); KernelToHandleMap[Kernel] = Handle; + assert(!ProteusAnnotatedKernels.contains(Kernel) && + "Expected kernel inserted only once in proteus kernel map"); + ProteusAnnotatedKernels.insert(Kernel); + JITKernelInfoMap[Kernel] = JITKernelInfo(KernelName, RCIndices, RCTypes, NumRCs); } @@ -608,15 +657,43 @@ void JitEngineDevice::registerLinkedBinary( ModuleIdToFatBinary[ModuleId] = FatbinWrapper; } +template +std::string JitEngineDevice::getCombinedModuleHash( + ArrayRef> LinkedModules) { + SmallVector SHA256HashCodes; + for (auto &Mod : LinkedModules) { + NamedMDNode *ProteusSHANode = + Mod->getNamedMetadata("proteus.module.sha256"); + assert(ProteusSHANode != nullptr && + "Expected non-null proteus.module.sha256 metadata"); + assert(ProteusSHANode->getNumOperands() == 1 && + 
"Hash MD Node should have a single operand"); + auto MDHash = ProteusSHANode->getOperand(0); + MDString *sha256 = dyn_cast(MDHash->getOperand(0)); + if (!sha256) { + FATAL_ERROR("Could not read sha256 from module\n"); + } + SHA256HashCodes.push_back(sha256->getString().str()); + Mod->eraseNamedMetadata(ProteusSHANode); + } + + std::sort(SHA256HashCodes.begin(), SHA256HashCodes.end()); + std::string combinedHash; + for (auto hash : SHA256HashCodes) { + combinedHash += hash; + } + return combinedHash; +} + template void JitEngineDevice::linkJitModule( - Module *M, LLVMContext *Ctx, StringRef KernelName, - SmallVector> &LinkedModules) { + Module &M, StringRef KernelName, + ArrayRef> LinkedModules) { if (LinkedModules.empty()) FATAL_ERROR("Expected jit module"); - Linker IRLinker(*M); - for (auto &LinkedM : LinkedModules) { + Linker IRLinker(M); + for (auto &LinkedM : llvm::reverse(LinkedModules)) { // Returns true if linking failed. if (IRLinker.linkInModule(std::move(LinkedM))) FATAL_ERROR("Linking failed"); diff --git a/lib/JitEngineDeviceCUDA.cpp b/lib/JitEngineDeviceCUDA.cpp index c69dab99..f7c8d3d7 100644 --- a/lib/JitEngineDeviceCUDA.cpp +++ b/lib/JitEngineDeviceCUDA.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include #include #include #include @@ -65,18 +66,28 @@ void JitEngineDeviceCUDA::extractLinkedBitcode( if (!M) FATAL_ERROR("unexpected"); - LinkedModules.push_back(std::move(M)); + LinkedModules.emplace_back(std::move(M)); } -std::unique_ptr -JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, void *Kernel) { +std::unique_ptr JitEngineDeviceCUDA::createLinkedModule( + ArrayRef> LinkedModules, + StringRef KernelName) { + auto JitModule = std::make_unique("JitModule", *ProteusCtx); + linkJitModule(*JitModule, KernelName, LinkedModules); + return std::move(JitModule); +} + +int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, + void *Kernel) { 
CUmodule CUMod; CUdeviceptr DevPtr; + size_t Bytes; + if (KernelToBitcodeIndex.contains(Kernel)) + return KernelToBitcodeIndex[Kernel]; + SmallVector> LinkedModules; - auto Ctx = std::make_unique(); - auto JitModule = std::make_unique("JitModule", *Ctx); if (!KernelToHandleMap.contains(Kernel)) FATAL_ERROR("Expected Kernel in map"); @@ -93,21 +104,30 @@ JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, void *Kernel) { cuErrCheck(cuModuleLoadData(&CUMod, FatbinWrapper->Binary)); for (auto &ModuleId : LinkedModuleIds) - extractLinkedBitcode(*Ctx.get(), CUMod, LinkedModules, ModuleId); + extractLinkedBitcode(*ProteusCtx, CUMod, LinkedModules, ModuleId); for (auto &ModuleId : GlobalLinkedModuleIds) - extractLinkedBitcode(*Ctx.get(), CUMod, LinkedModules, ModuleId); + extractLinkedBitcode(*ProteusCtx, CUMod, LinkedModules, ModuleId); cuErrCheck(cuModuleUnload(CUMod)); - linkJitModule(JitModule.get(), Ctx.get(), KernelName, LinkedModules); + // Store the linked modules. For future accesses + int index = SHA256HashWithBitcodes.size(); + SHA256HashWithBitcodes.push_back( + std::make_pair(getCombinedModuleHash(LinkedModules), + SmallVector>())); + // Iterate and pop elements + for (auto it = LinkedModules.rbegin(); it != LinkedModules.rend(); ++it) { + SHA256HashWithBitcodes[index].second.push_back(std::move(*it)); + } - std::string LinkedDeviceBitcode; - raw_string_ostream OS(LinkedDeviceBitcode); - WriteBitcodeToFile(*JitModule.get(), OS); - OS.flush(); + for (const auto &KV : KernelToHandleMap) { + if (KV.second == Handle) { + KernelToBitcodeIndex.try_emplace(KV.first, index); + } + } - return MemoryBuffer::getMemBufferCopy(LinkedDeviceBitcode); + return index; } void JitEngineDeviceCUDA::setLaunchBoundsForKernel(Module &M, Function &F, @@ -222,6 +242,7 @@ JitEngineDeviceCUDA::codegenObject(Module &M, StringRef DeviceArch) { nvPTXCompilerErrCheck( nvPTXCompilerCreate(&PTXCompiler, PTXStr.size(), PTXStr.data())); std::string ArchOpt = ("--gpu-name=" + 
DeviceArch).str(); + std::string RDCOption = ""; if (!GlobalLinkedBinaries.empty()) RDCOption = "-c"; diff --git a/lib/JitEngineDeviceCUDA.hpp b/lib/JitEngineDeviceCUDA.hpp index cb1674eb..95e578ac 100644 --- a/lib/JitEngineDeviceCUDA.hpp +++ b/lib/JitEngineDeviceCUDA.hpp @@ -84,8 +84,11 @@ class JitEngineDeviceCUDA : public JitEngineDevice { void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize); - std::unique_ptr extractDeviceBitcode(StringRef KernelName, - void *Kernel); + int extractDeviceBitcode(StringRef KernelName, void *Kernel); + + std::unique_ptr + createLinkedModule(ArrayRef> LinkedModules, + StringRef KernelName); void codegenPTX(Module &M, StringRef DeviceArch, SmallVectorImpl &PTXStr); diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 6eb243de..fc3e2b71 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -35,6 +35,7 @@ #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/SHA256.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -117,6 +119,22 @@ class ProteusJitPassImpl { StructType::create({Int128Ty, Int32Ty}, "struct.args", true); } + // Function to attach the SHA-256 checksum as custom metadata to the Module + NamedMDNode *addSHA256AsMetadata(Module &M, + const std::string &SHA256Checksum) { + // Get or create a named metadata node + NamedMDNode *SHA256Metadata = + M.getOrInsertNamedMetadata("proteus.module.sha256"); + + // Create a metadata node with the SHA-256 hash as a string + MDNode *SHA256Node = MDNode::get( + M.getContext(), MDString::get(M.getContext(), SHA256Checksum)); + + // Add the metadata node as an operand to the named metadata + SHA256Metadata->addOperand(SHA256Node); + return SHA256Metadata; + } + bool run(Module &M, bool IsLTO) { 
parseAnnotations(M); @@ -131,7 +149,15 @@ class ProteusJitPassImpl { // For device compilation, just extract the module IR of device code // and return. if (isDeviceCompilation(M)) { + // Calling the SHA256 on top of the module, as `hash_code` is not + // deterministic across executions + auto proteusHashMDNode = + addSHA256AsMetadata(M, getModuleBitcodeSHA256(M)); + std::cout << "Proteus hash is :\n"; + proteusHashMDNode->dump(); emitJitModuleDevice(M, IsLTO); + // The AOT compilation does not need the SHA, so we delete it + M.eraseNamedMetadata(proteusHashMDNode); return true; } @@ -1100,6 +1126,22 @@ class ProteusJitPassImpl { } } } + + // Function to get the SHA-256 checksum of a Module's bitcode + std::string getModuleBitcodeSHA256(const Module &M) { + std::string BitcodeStr; + llvm::raw_string_ostream BitcodeStream(BitcodeStr); + + WriteBitcodeToFile(M, BitcodeStream); + BitcodeStream.flush(); + + llvm::SHA256 SHA256Hasher; + SHA256Hasher.update(BitcodeStr); // Feed the bitcode to the hasher + auto Digest = SHA256Hasher.final(); // Finalize the hash computation + + // Convert the SHA-256 result to a human-readable hexadecimal string + return llvm::toHex(Digest); + } }; // New PM implementation. 
From 4f5b005d3da47b701e63ada968ce8173d60ee093 Mon Sep 17 00:00:00 2001 From: koparasy Date: Wed, 18 Dec 2024 08:46:45 -0800 Subject: [PATCH 2/5] Add Static Global Hash computation --- lib/JitCache.hpp | 16 ++++++--- lib/JitEngineDevice.hpp | 70 +++++++++++++++++++++++++++++++++++-- lib/JitEngineDeviceCUDA.hpp | 2 ++ lib/JitEngineDeviceHIP.hpp | 2 ++ pass/ProteusPass.cpp | 35 ++++++++++--------- 5 files changed, 102 insertions(+), 23 deletions(-) diff --git a/lib/JitCache.hpp b/lib/JitCache.hpp index 26361b38..15a93b0f 100644 --- a/lib/JitCache.hpp +++ b/lib/JitCache.hpp @@ -34,10 +34,20 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) { template class JitCache { public: + uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, const dim3 &GridDim, + const dim3 &BlockDim, const RuntimeConstant *RC, + int NumRuntimeConstants) const { + ArrayRef Data(RC, NumRuntimeConstants); + auto HashValue = + hash_combine(ModuleUniqueId, FnName, GridDim.x, GridDim.y, GridDim.z, + BlockDim.x, BlockDim.y, BlockDim.z, Data); + return HashValue; + } + uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, const RuntimeConstant *RC, int NumRuntimeConstants) const { ArrayRef Data(RC, NumRuntimeConstants); - auto HashValue = hash_combine(ExePath, ModuleUniqueId, FnName, Data); + auto HashValue = hash_combine(ModuleUniqueId, FnName, Data); return HashValue; } @@ -98,7 +108,6 @@ template class JitCache { JitCache() { // NOTE: Linux-specific. - ExePath = std::filesystem::canonical("/proc/self/exe"); } private: @@ -115,11 +124,10 @@ template class JitCache { DenseMap CacheMap; // Use the executable binary path when hashing to differentiate between // same-named kernels generated by other executables. 
- std::filesystem::path ExePath; uint64_t Hits = 0; uint64_t Accesses = 0; }; } // namespace proteus -#endif \ No newline at end of file +#endif diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index f02a92ba..08b79021 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -12,9 +12,12 @@ #define PROTEUS_JITENGINEDEVICE_HPP #include "llvm/Linker/Linker.h" +#include "llvm/Object/ELFObjectFile.h" +#include "llvm/Support/SHA256.h" #include #include #include +#include #include #include #include @@ -317,8 +320,61 @@ template class JitEngineDevice : protected JitEngine { relinkGlobals(Module &M, std::unordered_map &VarNameToDevPtr); + static std::string computeDeviceFatBinHash() { + TIMESCOPE("computeDeviceFatBinHash"); + using namespace llvm::object; + llvm::SHA256 sha256; + auto ExePath = std::filesystem::canonical("/proc/self/exe"); + + std::cout << "Reading file from path " << ExePath.string() << "\n"; + + auto bufferOrErr = MemoryBuffer::getFile(ExePath.string()); + if (!bufferOrErr) { + FATAL_ERROR("Failed to open binary file"); + } + + auto objOrErr = + ObjectFile::createELFObjectFile(bufferOrErr.get()->getMemBufferRef()); + if (!objOrErr) { + FATAL_ERROR("Failed to create Object File"); + } + + ObjectFile &elfObj = **objOrErr; + + // Step 3: Iterate through sections and get their contents + for (const SectionRef §ion : elfObj.sections()) { + auto nameOrErr = section.getName(); + if (!nameOrErr) + FATAL_ERROR("Error getting section name: "); + + StringRef sectionName = nameOrErr.get(); + if (sectionName.compare(ImplT::getFatBinSectionName()) != 0) + continue; + + // Get the contents of the section + auto contentsOrErr = section.getContents(); + if (!contentsOrErr) { + FATAL_ERROR("Error getting section contents: "); + continue; + } + StringRef sectionContents = contentsOrErr.get(); + + // Print section name and size + outs() << "Section: " << sectionName + << ", Size: " << sectionContents.size() << " bytes\n"; + 
sha256.update(sectionContents); + break; + } + auto sha256Hash = sha256.final(); + return llvm::toHex(sha256Hash); + } + protected: - JitEngineDevice() { ProteusCtx = std::make_unique(); } + JitEngineDevice() { + ProteusCtx = std::make_unique(); + ProteusDeviceBinHash = computeDeviceFatBinHash(); + std::cout << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"; + } ~JitEngineDevice() { CodeCache.printStats(); StorageCache.printStats(); @@ -345,6 +401,7 @@ template class JitEngineDevice : protected JitEngine { // All modules are associated with context, to guarantee correct lifetime. std::unique_ptr ProteusCtx; + std::string ProteusDeviceBinHash; private: // This map is private and only accessible via the API. @@ -487,9 +544,16 @@ JitEngineDevice::compileAndRun( // I have already read the LLVM IR from the Binary. Pick the Static Hash auto StaticHash = SHA256HashWithBitcodes[Index].first; // TODO: This does not include the GridDims/BlockDims. We need to fix it. - uint64_t DynamicHashValue = CodeCache.hash( - StaticHash, KernelName, RCsVec.data(), NumRuntimeConstants); + auto PersistentHash = ProteusDeviceBinHash; + uint64_t DynamicHashValue = + CodeCache.hash(PersistentHash, KernelName, GridDim, BlockDim, + RCsVec.data(), NumRuntimeConstants); KernelFunc = CodeCache.lookup(DynamicHashValue); + std::cout << " Function with name " << KernelName.str() << "at address " + << Kernel << " has PersistentHash " << PersistentHash + << " Static Hash:" << StaticHash + << " Dynamic Hash:" << DynamicHashValue << "\n"; + // We found the kernel, execute if (KernelFunc) return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, diff --git a/lib/JitEngineDeviceCUDA.hpp b/lib/JitEngineDeviceCUDA.hpp index 95e578ac..11cc83e9 100644 --- a/lib/JitEngineDeviceCUDA.hpp +++ b/lib/JitEngineDeviceCUDA.hpp @@ -79,6 +79,8 @@ class JitEngineDeviceCUDA : public JitEngineDevice { return "llvm.nvvm.read.ptx.sreg.tid.z"; }; + static const char *getFatBinSectionName() { return 
".nv_fatbin"; } + void *resolveDeviceGlobalAddr(const void *Addr); void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, diff --git a/lib/JitEngineDeviceHIP.hpp b/lib/JitEngineDeviceHIP.hpp index 10c02805..09c5a237 100644 --- a/lib/JitEngineDeviceHIP.hpp +++ b/lib/JitEngineDeviceHIP.hpp @@ -77,6 +77,8 @@ class JitEngineDeviceHIP : public JitEngineDevice { return "_ZNK17__HIP_CoordinatesI14__HIP_BlockIdxE3__ZcvjEv"; }; + static const char *getFatBinSectionName() { return ".hip_fatbin"; } + void *resolveDeviceGlobalAddr(const void *Addr); void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index fc3e2b71..8bc506cf 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -463,9 +463,9 @@ class ProteusJitPassImpl { std::string GVName = (IsLTO ? "__jit_bitcode_lto" : getJitBitcodeUniqueName(M)); - // NOTE: HIP compilation supports custom section in the binary to store the - // IR. CUDA does not, hence we parse the IR by reading the global from the - // device memory. + // NOTE: HIP compilation supports custom section in the binary to store + // the IR. CUDA does not, hence we parse the IR by reading the global + // from the device memory. Constant *JitModule = ConstantDataArray::get( M.getContext(), ArrayRef((const uint8_t *)BitcodeStr.data(), BitcodeStr.size())); @@ -585,8 +585,9 @@ class ProteusJitPassImpl { } Value *getStubGV(Value *Operand) { - // NOTE: when called by isDeviceKernelHostStub, Operand may not be a global - // variable point to the stub, so we check and return null in that case. + // NOTE: when called by isDeviceKernelHostStub, Operand may not be a + // global variable point to the stub, so we check and return null in that + // case. 
Value *V = nullptr; #if ENABLE_HIP // NOTE: Hip creates a global named after the device kernel function that @@ -1031,8 +1032,8 @@ class ProteusJitPassImpl { ArrayType::get(Int32Ty, NumRuntimeConstants); IRBuilder<> Builder(RegisterCB->getNextNode()); - // Create an array representing the indices of the args which are runtime - // constants. + // Create an array representing the indices of the args which are + // runtime constants. Value *RuntimeConstantsIndicesAlloca = Builder.CreateAlloca(RuntimeConstantArrayTy); assert(RuntimeConstantsIndicesAlloca && @@ -1081,14 +1082,16 @@ class ProteusJitPassImpl { } void findJitVariables(Module &M) { - DEBUG(Logger::logs("proteus-pass") << "finding jit variables" << "\n"); - DEBUG(Logger::logs("proteus-pass") << "users..." << "\n"); + DEBUG(Logger::logs("proteus-pass") << "finding jit variables" + << "\n"); + DEBUG(Logger::logs("proteus-pass") << "users..." + << "\n"); SmallVector JitFunctions; for (auto &F : M.getFunctionList()) { - // TODO: Demangle and search for the fully qualified proteus::jit_variable - // name. + // TODO: Demangle and search for the fully qualified + // proteus::jit_variable name. if (F.getName().contains("jit_variable")) { JitFunctions.push_back(&F); } @@ -1118,8 +1121,8 @@ class ProteusJitPassImpl { DEBUG(Logger::logs("proteus-pass") << "slot: " << *Slot << "\n"); CB->setArgOperand(1, Slot); } else { - DEBUG(Logger::logs("proteus-pass") - << "no gep, assuming slot 0" << "\n"); + DEBUG(Logger::logs("proteus-pass") << "no gep, assuming slot 0" + << "\n"); Constant *C = ConstantInt::get(Int32Ty, 0); CB->setArgOperand(1, C); } @@ -1160,9 +1163,9 @@ struct ProteusJitPass : PassInfoMixin { return PreservedAnalyses::all(); } - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. 
+ // Without isRequired returning true, this pass will be skipped for + // functions decorated with the optnone LLVM attribute. Note that clang -O0 + // decorates all functions with optnone. static bool isRequired() { return true; } }; From 1b5660ca6a4d6cf799969702656f5a4b50c0e28d Mon Sep 17 00:00:00 2001 From: koparasy Date: Wed, 18 Dec 2024 11:48:26 -0800 Subject: [PATCH 3/5] Compute SHA256 of module --- lib/CompilerInterfaceDevice.cpp | 12 ++++-- lib/JitEngineDevice.hpp | 19 +++++---- pass/ProteusPass.cpp | 75 ++++++++++++++++++++++++++++++--- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/lib/CompilerInterfaceDevice.cpp b/lib/CompilerInterfaceDevice.cpp index 01f4da7c..58035eb2 100644 --- a/lib/CompilerInterfaceDevice.cpp +++ b/lib/CompilerInterfaceDevice.cpp @@ -10,6 +10,8 @@ #include "CompilerInterfaceDevice.h" #include "JitEngineDevice.hpp" +#include +#include using namespace proteus; @@ -46,8 +48,12 @@ __jit_register_linked_binary(void *FatbinWrapper, const char *ModuleId) { } extern "C" __attribute((used)) void -__jit_register_function(void *Handle, void *Kernel, char *KernelName, - int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { +__jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel, + char *KernelName, int32_t *RCIndices, int32_t *RCTypes, + int32_t NumRCs) { auto &Jit = JitDeviceImplT::instance(); - Jit.registerFunction(Handle, Kernel, KernelName, RCIndices, RCTypes, NumRCs); + std::cout << "Got ModuleSHA256 " + << llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n"; + Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices, + RCTypes, NumRCs); } diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index 08b79021..0f53458a 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -59,15 +59,17 @@ namespace proteus { using namespace llvm; class JITKernelInfo { + StringRef SHA256; char const *Name; SmallVector RCTypes; SmallVector RCIndices; int32_t NumRCs; public: - 
JITKernelInfo(char const *Name, int32_t *RCIndices, int32_t *RCTypes, - int32_t NumRCs) - : Name(Name), RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, + JITKernelInfo(const char *SHA256, char const *Name, int32_t *RCIndices, + int32_t *RCTypes, int32_t NumRCs) + : SHA256(SHA256, 32), Name(Name), + RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, RCTypes{ArrayRef{RCTypes, static_cast(NumRCs)}}, NumRCs(NumRCs) {} @@ -109,8 +111,9 @@ template class JitEngineDevice : protected JitEngine { void registerFatBinary(void *Handle, FatbinWrapper_t *FatbinWrapper, const char *ModuleId); void registerFatBinaryEnd(); - void registerFunction(void *Handle, void *Kernel, char *KernelName, - int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs); + void registerFunction(void *Handle, const char *SHA256, void *Kernel, + char *KernelName, int32_t *RCIndices, int32_t *RCTypes, + int32_t NumRCs); struct BinaryInfo { FatbinWrapper_t *FatbinWrapper; @@ -685,8 +688,8 @@ template void JitEngineDevice::registerFatBinaryEnd() { } template -void JitEngineDevice::registerFunction(void *Handle, void *Kernel, - char *KernelName, +void JitEngineDevice::registerFunction(void *Handle, const char *SHA256, + void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { @@ -701,7 +704,7 @@ void JitEngineDevice::registerFunction(void *Handle, void *Kernel, ProteusAnnotatedKernels.insert(Kernel); JITKernelInfoMap[Kernel] = - JITKernelInfo(KernelName, RCIndices, RCTypes, NumRCs); + JITKernelInfo(SHA256, KernelName, RCIndices, RCTypes, NumRCs); } template diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 8bc506cf..9a2c791c 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -44,7 +44,9 @@ #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include #include +#include #include #include #include @@ -91,6 +93,7 @@ constexpr char const *RegisterFunctionName = 
"__cudaRegisterFunction"; constexpr char const *LaunchFunctionName = "cudaLaunchKernel"; constexpr char const *RegisterVarName = "__cudaRegisterVar"; constexpr char const *RegisterFatBinaryName = "__cudaRegisterFatBinary"; +constexpr char const *FatBinWrapperSymbolName = "__cuda_fatbin_wrapper"; #else constexpr char const *RegisterFunctionName = nullptr; constexpr char const *LaunchFunctionName = nullptr; @@ -153,8 +156,6 @@ class ProteusJitPassImpl { // deterministic across executions auto proteusHashMDNode = addSHA256AsMetadata(M, getModuleBitcodeSHA256(M)); - std::cout << "Proteus hash is :\n"; - proteusHashMDNode->dump(); emitJitModuleDevice(M, IsLTO); // The AOT compilation does not need the SHA, so we delete it M.eraseNamedMetadata(proteusHashMDNode); @@ -198,6 +199,16 @@ class ProteusJitPassImpl { if (verifyModule(M, &errs())) FATAL_ERROR("Broken original module found, compilation aborted!"); + std::filesystem::path ModulePath(M.getSourceFileName()); + std::filesystem::path filename(M.getSourceFileName()); + std::string rrBC(Twine(filename.filename().string(), ".host.bc").str()); + std::error_code EC; + raw_fd_ostream OutBC(rrBC, EC); + if (EC) + throw std::runtime_error("Cannot open device code " + rrBC); + OutBC << M; + OutBC.close(); + return true; } @@ -954,14 +965,15 @@ class ProteusJitPassImpl { FunctionCallee getJitRegisterFunctionFn(Module &M) { // The prototype is // __jit_register_function(void *Handle, + // const char ModuleSHA256, // void *Kernel, // char const *KernelName, // int32_t *RCIndices, // int32_t *RCTypes, // int32_t NumRCs) - FunctionType *JitRegisterFunctionFnTy = - FunctionType::get(VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty}, - /* isVarArg=*/false); + FunctionType *JitRegisterFunctionFnTy = FunctionType::get( + VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty}, + /* isVarArg=*/false); FunctionCallee JitRegisterKernelFn = M.getOrInsertFunction( "__jit_register_function", JitRegisterFunctionFnTy); @@ -1072,15 
+1084,66 @@ class ProteusJitPassImpl { FunctionCallee JitRegisterFunction = getJitRegisterFunctionFn(M); + auto SHA256Global = getOrCreateModuleSHAGlobal(M); + Value *SHAValue = Builder.CreateBitCast(SHA256Global, PtrTy); + SHAValue->dump(); + constexpr int StubOperand = 1; Builder.CreateCall( JitRegisterFunction, - {RegisterCB->getArgOperand(0), RegisterCB->getArgOperand(1), + {RegisterCB->getArgOperand(0), SHAValue, RegisterCB->getArgOperand(1), RegisterCB->getArgOperand(2), RuntimeConstantsIndicesAlloca, RuntimeConstantsTypesAlloca, NumRCsValue}); } } + GlobalVariable *computeSHA256(Module &M, GlobalVariable *GV) { + assert(GV->hasInitializer() && + "Global Variable must have initializer to compute the SHA256"); + LLVMContext &Context = M.getContext(); + auto *LLVMDeviceIR = dyn_cast(GV->getInitializer()); + StringRef Data = LLVMDeviceIR->getRawDataValues(); + SHA256 Hash; + Hash.update(Data); + auto HashValue = Hash.result(); + std::cout << "Hash value is " << llvm::toHex(HashValue) << "\n"; + std::vector HashBytes; + for (uint8_t Byte : HashValue) { + HashBytes.push_back(ConstantInt::get(Type::getInt8Ty(Context), Byte)); + } + + // Create a ConstantDataArray from the hash bytes + ArrayType *HashType = + ArrayType::get(Type::getInt8Ty(Context), HashBytes.size()); + Constant *HashInitializer = ConstantArray::get(HashType, HashBytes); + auto *HashGlobal = new GlobalVariable(M, HashType, /* isConstant */ true, + GlobalValue::PrivateLinkage, + HashInitializer, "sha256_hash"); + // Add a null terminator to the hash bytes + + return HashGlobal; + } + + GlobalVariable *getOrCreateModuleSHAGlobal(Module &M) { + GlobalVariable *ProteusSHA256GV = M.getGlobalVariable("sha256_hash"); + if (ProteusSHA256GV != nullptr) + return ProteusSHA256GV; + + GlobalVariable *FatbinWrapper = + M.getGlobalVariable(FatBinWrapperSymbolName, true); + assert(FatbinWrapper && "Expected existing fat bin wrapper"); + ConstantStruct *C = + dyn_cast(FatbinWrapper->getInitializer()); + 
assert(C->getType()->getNumElements() == 4 && + "Expected four fields in fatbin wrapper struct"); + constexpr int FatbinField = 2; + auto *Fatbin = C->getAggregateElement(FatbinField); + GlobalVariable *FatbinGV = dyn_cast(Fatbin); + assert(FatbinGV && "Expected global variable for the fatbin object"); + ProteusSHA256GV = computeSHA256(M, FatbinGV); + return ProteusSHA256GV; + } + void findJitVariables(Module &M) { DEBUG(Logger::logs("proteus-pass") << "finding jit variables" << "\n"); From 85208a7e808eadae60576bcb8493b38b344afd3c Mon Sep 17 00:00:00 2001 From: koparasy Date: Wed, 18 Dec 2024 12:05:06 -0800 Subject: [PATCH 4/5] Correct cuda implementation --- lib/CompilerInterfaceDevice.cpp | 2 - lib/JitCache.hpp | 12 +-- lib/JitEngineDevice.hpp | 140 ++++++++++---------------------- lib/JitEngineDeviceCUDA.cpp | 36 +++----- lib/JitEngineDeviceCUDA.hpp | 13 +-- pass/ProteusPass.cpp | 64 ++++----------- 6 files changed, 81 insertions(+), 186 deletions(-) diff --git a/lib/CompilerInterfaceDevice.cpp b/lib/CompilerInterfaceDevice.cpp index 58035eb2..b93ea518 100644 --- a/lib/CompilerInterfaceDevice.cpp +++ b/lib/CompilerInterfaceDevice.cpp @@ -52,8 +52,6 @@ __jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { auto &Jit = JitDeviceImplT::instance(); - std::cout << "Got ModuleSHA256 " - << llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n"; Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices, RCTypes, NumRCs); } diff --git a/lib/JitCache.hpp b/lib/JitCache.hpp index 15a93b0f..075a2f6e 100644 --- a/lib/JitCache.hpp +++ b/lib/JitCache.hpp @@ -34,13 +34,13 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) { template class JitCache { public: - uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, const dim3 &GridDim, - const dim3 &BlockDim, const RuntimeConstant *RC, - int NumRuntimeConstants) const { + uint64_t hash(StringRef 
ModuleUniqueId, StringRef KernelHash, + StringRef KernelName, const dim3 &GridDim, const dim3 &BlockDim, + const RuntimeConstant *RC, int NumRuntimeConstants) const { ArrayRef Data(RC, NumRuntimeConstants); - auto HashValue = - hash_combine(ModuleUniqueId, FnName, GridDim.x, GridDim.y, GridDim.z, - BlockDim.x, BlockDim.y, BlockDim.z, Data); + auto HashValue = hash_combine(ModuleUniqueId, KernelHash, KernelName, + GridDim.x, GridDim.y, GridDim.z, BlockDim.x, + BlockDim.y, BlockDim.z, Data); return HashValue; } diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index 0f53458a..485a9937 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -78,6 +78,7 @@ class JITKernelInfo { const auto &getRCIndices() const { return RCIndices; } const auto &getRCTypes() const { return RCTypes; } const auto &getNumRCs() const { return NumRCs; } + const auto &getSHA256() const { return SHA256; } }; struct FatbinWrapper_t { @@ -125,11 +126,7 @@ template class JitEngineDevice : protected JitEngine { DenseMap HandleToBinaryInfo; DenseMap KernelToHandleMap; - SmallVector>>> - SHA256HashWithBitcodes; - - DenseMap> LinkedLLVMIRModules; - DenseMap KernelToBitcodeIndex; + DenseMap> KernelToLinkedBitcode; /* @Brief After proteus initialization contains all kernels annotathed with * proteus */ DenseSet ProteusAnnotatedKernels; @@ -298,15 +295,7 @@ template class JitEngineDevice : protected JitEngine { Image); } - std::unique_ptr - createLinkedModule(ArrayRef> LinkedModules, - StringRef KernelName) { - TIMESCOPE(__FUNCTION__) - return static_cast(*this).createLinkedModule(LinkedModules, - KernelName); - } - - int extractDeviceBitcode(StringRef KernelName, void *Kernel) { + void extractDeviceBitcode(StringRef KernelName, void *Kernel) { TIMESCOPE(__FUNCTION__) return static_cast(*this).extractDeviceBitcode(KernelName, Kernel); } @@ -344,17 +333,18 @@ template class JitEngineDevice : protected JitEngine { ObjectFile &elfObj = **objOrErr; - // Step 3: Iterate through 
sections and get their contents for (const SectionRef &section : elfObj.sections()) { auto nameOrErr = section.getName(); if (!nameOrErr) FATAL_ERROR("Error getting section name: "); StringRef sectionName = nameOrErr.get(); - if (sectionName.compare(ImplT::getFatBinSectionName()) != 0) + if (!ImplT::HashSection(sectionName)) continue; - // Get the contents of the section + DBG(Logger::logs("proteus") + << "Hashing section " << sectionName.str() << "\n"); + auto contentsOrErr = section.getContents(); if (!contentsOrErr) { FATAL_ERROR("Error getting section contents: "); @@ -362,11 +352,7 @@ template class JitEngineDevice : protected JitEngine { } StringRef sectionContents = contentsOrErr.get(); - // Print section name and size - outs() << "Section: " << sectionName - << ", Size: " << sectionContents.size() << " bytes\n"; sha256.update(sectionContents); - break; } auto sha256Hash = sha256.final(); return llvm::toHex(sha256Hash); @@ -376,31 +362,25 @@ template class JitEngineDevice : protected JitEngine { JitEngineDevice() { ProteusCtx = std::make_unique(); ProteusDeviceBinHash = computeDeviceFatBinHash(); - std::cout << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"; + DBG(Logger::logs("proteus") + << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"); } ~JitEngineDevice() { CodeCache.printStats(); StorageCache.printStats(); // Note: We manually clear or unique_ptr to Modules before the destructor // releases the ProteusCtx.
- // - // Explicitly clear the LinkedLLVMIRModules - LinkedLLVMIRModules.clear(); - - // Explicitly clear SHA256HashWithBitcodes - for (auto &Entry : SHA256HashWithBitcodes) - Entry.second.clear(); - SHA256HashWithBitcodes.clear(); + + KernelToLinkedBitcode.clear(); } JitCache CodeCache; JitStorageCache StorageCache; std::string DeviceArch; std::unordered_map VarNameToDevPtr; - void linkJitModule(Module &M, StringRef KernelName, - ArrayRef> LinkedModules); - std::string - getCombinedModuleHash(ArrayRef> LinkedModules); + std::unique_ptr + linkJitModule(StringRef KernelName, + SmallVector> &LinkedModules); // All modules are associated with context, to guarantee correct lifetime. std::unique_ptr ProteusCtx; @@ -540,22 +520,17 @@ JitEngineDevice::compileAndRun( typename DeviceTraits::KernelFunction_t KernelFunc; - auto Index = KernelToBitcodeIndex.contains(Kernel) - ? KernelToBitcodeIndex[Kernel] - : extractDeviceBitcode(KernelName, Kernel); - - // I have already read the LLVM IR from the Binary. Pick the Static Hash - auto StaticHash = SHA256HashWithBitcodes[Index].first; - // TODO: This does not include the GridDims/BlockDims. We need to fix it. 
- auto PersistentHash = ProteusDeviceBinHash; - uint64_t DynamicHashValue = - CodeCache.hash(PersistentHash, KernelName, GridDim, BlockDim, + auto FatBinHash = ProteusDeviceBinHash; + auto KernelHash = JITKernelInfoMap[Kernel].getSHA256(); + uint64_t CombinedHash = + CodeCache.hash(FatBinHash, KernelHash, KernelName, GridDim, BlockDim, RCsVec.data(), NumRuntimeConstants); - KernelFunc = CodeCache.lookup(DynamicHashValue); - std::cout << " Function with name " << KernelName.str() << "at address " - << Kernel << " has PersistentHash " << PersistentHash - << " Static Hash:" << StaticHash - << " Dynamic Hash:" << DynamicHashValue << "\n"; + KernelFunc = CodeCache.lookup(CombinedHash); + + DBG(Logger::logs("proteus") + << " Function with name " << KernelName.str() << "at address " << Kernel + << " has PersistentHash " << FatBinHash << " Static Hash:" + << llvm::toHex(KernelHash) << " Dynamic Hash:" << CombinedHash << "\n"); // We found the kernel, execute if (KernelFunc) @@ -565,7 +540,7 @@ JitEngineDevice::compileAndRun( // NOTE: we don't need a suffix to differentiate kernels, each specialization // will be in its own module uniquely identify by HashValue. It exists only // for debugging purposes to verify that the jitted kernel executes. - std::string Suffix = mangleSuffix(DynamicHashValue); + std::string Suffix = mangleSuffix(CombinedHash); std::string KernelMangled = (KernelName + Suffix).str(); if (Config.ENV_PROTEUS_USE_STORED_CACHE) { @@ -582,9 +557,8 @@ JitEngineDevice::compileAndRun( bool HasDeviceGlobals = !VarNameToDevPtr.empty(); if (auto CacheBuf = (HasDeviceGlobals - ? StorageCache.lookupBitcode(DynamicHashValue, KernelMangled) - : StorageCache.lookupObject(DynamicHashValue, - KernelMangled))) { + ? 
StorageCache.lookupBitcode(CombinedHash, KernelMangled) + : StorageCache.lookupObject(CombinedHash, KernelMangled))) { std::unique_ptr ObjBuf; if (HasDeviceGlobals) { SMDiagnostic Err; @@ -598,7 +572,7 @@ JitEngineDevice::compileAndRun( auto KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -606,18 +580,13 @@ JitEngineDevice::compileAndRun( } } - if (!LinkedLLVMIRModules.contains(Kernel)) { - // if we get here, we have access to the LLVM-IR of the module, but we - // have never linked everything together and internalized the symbols. - LinkedLLVMIRModules.insert( - {Kernel, - createLinkedModule(SHA256HashWithBitcodes[Index].second, KernelName)}); - } + if (!KernelToLinkedBitcode.contains(Kernel)) + extractDeviceBitcode(KernelName, Kernel); // We need to clone, The JitModule will be specialized later, and we need - // the one stored under LinkedLLVMIRModules to be a generic version prior + // the one generic one in KernelToLinkedBitcode to be a generic version prior // specialization. 
- auto JitModule = llvm::CloneModule(*LinkedLLVMIRModules[Kernel]); + auto JitModule = llvm::CloneModule(*KernelToLinkedBitcode[Kernel]); specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices, RCsVec.data(), NumRuntimeConstants); @@ -629,18 +598,18 @@ JitEngineDevice::compileAndRun( SmallString<4096> ModuleBuffer; raw_svector_ostream ModuleBufferOS(ModuleBuffer); WriteBitcodeToFile(*JitModule, ModuleBufferOS); - StorageCache.storeBitcode(DynamicHashValue, ModuleBuffer); + StorageCache.storeBitcode(CombinedHash, ModuleBuffer); relinkGlobals(*JitModule, VarNameToDevPtr); auto ObjBuf = codegenObject(*JitModule, DeviceArch); if (Config.ENV_PROTEUS_USE_STORED_CACHE) - StorageCache.storeObject(DynamicHashValue, ObjBuf->getMemBufferRef()); + StorageCache.storeObject(CombinedHash, ObjBuf->getMemBufferRef()); KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -725,46 +694,19 @@ void JitEngineDevice::registerLinkedBinary( } template -std::string JitEngineDevice::getCombinedModuleHash( - ArrayRef> LinkedModules) { - SmallVector SHA256HashCodes; - for (auto &Mod : LinkedModules) { - NamedMDNode *ProteusSHANode = - Mod->getNamedMetadata("proteus.module.sha256"); - assert(ProteusSHANode != nullptr && - "Expected non-null proteus.module.sha256 metadata"); - assert(ProteusSHANode->getNumOperands() == 1 && - "Hash MD Node should have a single operand"); - auto MDHash = ProteusSHANode->getOperand(0); - MDString *sha256 = dyn_cast(MDHash->getOperand(0)); - if (!sha256) { - FATAL_ERROR("Could not read sha256 from module\n"); - } - SHA256HashCodes.push_back(sha256->getString().str()); - Mod->eraseNamedMetadata(ProteusSHANode); - } - - std::sort(SHA256HashCodes.begin(), SHA256HashCodes.end()); - 
std::string combinedHash; - for (auto hash : SHA256HashCodes) { - combinedHash += hash; - } - return combinedHash; -} - -template -void JitEngineDevice::linkJitModule( - Module &M, StringRef KernelName, - ArrayRef> LinkedModules) { +std::unique_ptr JitEngineDevice::linkJitModule( + StringRef KernelName, SmallVector> &LinkedModules) { if (LinkedModules.empty()) FATAL_ERROR("Expected jit module"); - Linker IRLinker(M); - for (auto &LinkedM : llvm::reverse(LinkedModules)) { + auto LinkedModule = std::make_unique("JitModule", *ProteusCtx); + Linker IRLinker(*LinkedModule); + for (auto &LinkedM : LinkedModules) { // Returns true if linking failed. if (IRLinker.linkInModule(std::move(LinkedM))) FATAL_ERROR("Linking failed"); } + return LinkedModule; } } // namespace proteus diff --git a/lib/JitEngineDeviceCUDA.cpp b/lib/JitEngineDeviceCUDA.cpp index f7c8d3d7..c54f6151 100644 --- a/lib/JitEngineDeviceCUDA.cpp +++ b/lib/JitEngineDeviceCUDA.cpp @@ -66,26 +66,18 @@ void JitEngineDeviceCUDA::extractLinkedBitcode( if (!M) FATAL_ERROR("unexpected"); - LinkedModules.emplace_back(std::move(M)); + LinkedModules.push_back(std::move(M)); } -std::unique_ptr JitEngineDeviceCUDA::createLinkedModule( - ArrayRef> LinkedModules, - StringRef KernelName) { - auto JitModule = std::make_unique("JitModule", *ProteusCtx); - linkJitModule(*JitModule, KernelName, LinkedModules); - return std::move(JitModule); -} - -int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, - void *Kernel) { +void JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, + void *Kernel) { CUmodule CUMod; CUdeviceptr DevPtr; size_t Bytes; - if (KernelToBitcodeIndex.contains(Kernel)) - return KernelToBitcodeIndex[Kernel]; + if (KernelToLinkedBitcode.contains(Kernel)) + return; SmallVector> LinkedModules; if (!KernelToHandleMap.contains(Kernel)) @@ -111,23 +103,18 @@ int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, cuErrCheck(cuModuleUnload(CUMod)); - // Store the linked modules. 
For future accesses - int index = SHA256HashWithBitcodes.size(); - SHA256HashWithBitcodes.push_back( - std::make_pair(getCombinedModuleHash(LinkedModules), - SmallVector>())); - // Iterate and pop elements - for (auto it = LinkedModules.rbegin(); it != LinkedModules.rend(); ++it) { - SHA256HashWithBitcodes[index].second.push_back(std::move(*it)); - } + auto JitModule = + std::shared_ptr(linkJitModule(KernelName, LinkedModules)); for (const auto &KV : KernelToHandleMap) { + // All kernels included in this collection of modules will have an identical + // non specialized IR file. Map all Kernels, to this generic IR file if (KV.second == Handle) { - KernelToBitcodeIndex.try_emplace(KV.first, index); + KernelToLinkedBitcode.try_emplace(KV.first, JitModule); } } - return index; + return; } void JitEngineDeviceCUDA::setLaunchBoundsForKernel(Module &M, Function &F, @@ -242,7 +229,6 @@ JitEngineDeviceCUDA::codegenObject(Module &M, StringRef DeviceArch) { nvPTXCompilerErrCheck( nvPTXCompilerCreate(&PTXCompiler, PTXStr.size(), PTXStr.data())); std::string ArchOpt = ("--gpu-name=" + DeviceArch).str(); - std::string RDCOption = ""; if (!GlobalLinkedBinaries.empty()) RDCOption = "-c"; diff --git a/lib/JitEngineDeviceCUDA.hpp b/lib/JitEngineDeviceCUDA.hpp index 11cc83e9..b9ec7e95 100644 --- a/lib/JitEngineDeviceCUDA.hpp +++ b/lib/JitEngineDeviceCUDA.hpp @@ -15,6 +15,7 @@ #include "Utils.h" #include #include +#include namespace proteus { @@ -79,18 +80,18 @@ class JitEngineDeviceCUDA : public JitEngineDevice { return "llvm.nvvm.read.ptx.sreg.tid.z"; }; - static const char *getFatBinSectionName() { return ".nv_fatbin"; } + static bool HashSection(StringRef sectionName) { + static DenseSet Sections{".nv_fatbin", ".nvFatBinSegment", + "__nv_relfatbin", "__nv_module_id"}; + return Sections.contains(sectionName); + } void *resolveDeviceGlobalAddr(const void *Addr); void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize); - int 
extractDeviceBitcode(StringRef KernelName, void *Kernel); - - std::unique_ptr - createLinkedModule(ArrayRef> LinkedModules, - StringRef KernelName); + void extractDeviceBitcode(StringRef KernelName, void *Kernel); void codegenPTX(Module &M, StringRef DeviceArch, SmallVectorImpl &PTXStr); diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 9a2c791c..36ad65b1 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -122,22 +122,6 @@ class ProteusJitPassImpl { StructType::create({Int128Ty, Int32Ty}, "struct.args", true); } - // Function to attach the SHA-256 checksum as custom metadata to the Module - NamedMDNode *addSHA256AsMetadata(Module &M, - const std::string &SHA256Checksum) { - // Get or create a named metadata node - NamedMDNode *SHA256Metadata = - M.getOrInsertNamedMetadata("proteus.module.sha256"); - - // Create a metadata node with the SHA-256 hash as a string - MDNode *SHA256Node = MDNode::get( - M.getContext(), MDString::get(M.getContext(), SHA256Checksum)); - - // Add the metadata node as an operand to the named metadata - SHA256Metadata->addOperand(SHA256Node); - return SHA256Metadata; - } - bool run(Module &M, bool IsLTO) { parseAnnotations(M); @@ -152,14 +136,9 @@ class ProteusJitPassImpl { // For device compilation, just extract the module IR of device code // and return. if (isDeviceCompilation(M)) { - // Calling the SHA256 on top of the module, as `hash_code` is not - // deterministic across executions - auto proteusHashMDNode = - addSHA256AsMetadata(M, getModuleBitcodeSHA256(M)); + DEBUG(dump(M, "device", IsLTO ? "lto-before-proteus" : "before-proteus")); emitJitModuleDevice(M, IsLTO); - // The AOT compilation does not need the SHA, so we delete it - M.eraseNamedMetadata(proteusHashMDNode); - + DEBUG(dump(M, "device", IsLTO ? 
"lto-after-proteus" : "after-proteus")); return true; } @@ -199,15 +178,7 @@ class ProteusJitPassImpl { if (verifyModule(M, &errs())) FATAL_ERROR("Broken original module found, compilation aborted!"); - std::filesystem::path ModulePath(M.getSourceFileName()); - std::filesystem::path filename(M.getSourceFileName()); - std::string rrBC(Twine(filename.filename().string(), ".host.bc").str()); - std::error_code EC; - raw_fd_ostream OutBC(rrBC, EC); - if (EC) - throw std::runtime_error("Cannot open device code " + rrBC); - OutBC << M; - OutBC.close(); + DEBUG(dump(M, "host", "after-proteus")); return true; } @@ -226,6 +197,19 @@ class ProteusJitPassImpl { std::string ModuleIR; }; + void dump(Module M, StringRef device, StringRef phase) { + std::filesystem::path ModulePath(M.getSourceFileName()); + std::filesystem::path filename(M.getSourceFileName()); + std::string rrBC( + Twine(filename.filename().string() + "." + device + "." + phase).str()); + std::error_code EC; + raw_fd_ostream OutBC(rrBC, EC); + if (EC) + throw std::runtime_error("Cannot open device code " + rrBC); + OutBC << M; + OutBC.close(); + } + MapVector JitFunctionInfoMap; DenseMap StubToKernelMap; SmallPtrSet ModuleDeviceKernels; @@ -1192,22 +1176,6 @@ class ProteusJitPassImpl { } } } - - // Function to get the SHA-256 checksum of a Module's bitcode - std::string getModuleBitcodeSHA256(const Module &M) { - std::string BitcodeStr; - llvm::raw_string_ostream BitcodeStream(BitcodeStr); - - WriteBitcodeToFile(M, BitcodeStream); - BitcodeStream.flush(); - - llvm::SHA256 SHA256Hasher; - SHA256Hasher.update(BitcodeStr); // Feed the bitcode to the hasher - auto Digest = SHA256Hasher.final(); // Finalize the hash computation - - // Convert the SHA-256 result to a human-readable hexadecimal string - return llvm::toHex(Digest); - } }; // New PM implementation. 
From 7d77c58a97286a0f1bda84dbce338dfa7db9f475 Mon Sep 17 00:00:00 2001 From: koparasy Date: Wed, 18 Dec 2024 16:36:21 -0800 Subject: [PATCH 5/5] Working HIP --- lib/JitCache.hpp | 3 +- lib/JitEngineDevice.hpp | 39 +++++++++++-------- lib/JitEngineDeviceHIP.cpp | 80 +++++++++++++++++++++++++++----------- lib/JitEngineDeviceHIP.hpp | 10 +++-- pass/ProteusPass.cpp | 39 +++++++++++++------ 5 files changed, 116 insertions(+), 55 deletions(-) diff --git a/lib/JitCache.hpp b/lib/JitCache.hpp index 075a2f6e..d3c856fb 100644 --- a/lib/JitCache.hpp +++ b/lib/JitCache.hpp @@ -34,8 +34,9 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) { template class JitCache { public: + template uint64_t hash(StringRef ModuleUniqueId, StringRef KernelHash, - StringRef KernelName, const dim3 &GridDim, const dim3 &BlockDim, + StringRef KernelName, const T &GridDim, const T &BlockDim, const RuntimeConstant *RC, int NumRuntimeConstants) const { ArrayRef Data(RC, NumRuntimeConstants); auto HashValue = hash_combine(ModuleUniqueId, KernelHash, KernelName, diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index 485a9937..1655f9a3 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -66,12 +66,14 @@ class JITKernelInfo { int32_t NumRCs; public: - JITKernelInfo(const char *SHA256, char const *Name, int32_t *RCIndices, + JITKernelInfo(const char *_SHA256, char const *Name, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) - : SHA256(SHA256, 32), Name(Name), - RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, + : Name(Name), RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, RCTypes{ArrayRef{RCTypes, static_cast(NumRCs)}}, - NumRCs(NumRCs) {} + NumRCs(NumRCs) { + if (_SHA256) + SHA256 = StringRef(_SHA256, 32); + } JITKernelInfo() : Name(nullptr), NumRCs(0), RCIndices(), RCTypes() {} const auto &getName() const { return Name; } @@ -79,6 +81,7 @@ class JITKernelInfo { const auto &getRCTypes() const { return RCTypes; } const auto &getNumRCs() const 
{ return NumRCs; } const auto &getSHA256() const { return SHA256; } + void setSHA256(StringRef sha) { SHA256 = StringRef(sha); } }; struct FatbinWrapper_t { @@ -318,7 +321,8 @@ template class JitEngineDevice : protected JitEngine { llvm::SHA256 sha256; auto ExePath = std::filesystem::canonical("/proc/self/exe"); - std::cout << "Reading file from path " << ExePath.string() << "\n"; + DBG(Logger::logs("proteus") + << "Reading file from path " << ExePath.string() << "\n"); auto bufferOrErr = MemoryBuffer::getFile(ExePath.string()); if (!bufferOrErr) { @@ -363,14 +367,13 @@ template class JitEngineDevice : protected JitEngine { ProteusCtx = std::make_unique(); ProteusDeviceBinHash = computeDeviceFatBinHash(); DBG(Logger::logs("proteus") - << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"); + << "Persistent Hash is " << ProteusDeviceBinHash << "\n"); } ~JitEngineDevice() { CodeCache.printStats(); StorageCache.printStats(); - // Note: We manually clear or unique_ptr to Modules before the destructor - // releases the ProteusCtx. - + // Note: We clear KernelToLinkedBitcode manually. Otherwise we have a + // destruction-order fiasco: ProteusCtx is destroyed before the modules KernelToLinkedBitcode.clear(); } @@ -386,8 +389,6 @@ template class JitEngineDevice : protected JitEngine { std::unique_ptr ProteusCtx; std::string ProteusDeviceBinHash; -private: - // This map is private and only accessible via the API. DenseMap JITKernelInfoMap; }; @@ -514,6 +515,14 @@ JitEngineDevice::compileAndRun( return launchKernelDirect(Kernel, GridDim, BlockDim, KernelArgs, ShmemSize, Stream); + // This needs to happen early. In HIP we need to have parsed the IR at + // least once to get access to the "KernelHash". extractDeviceBitcode + // populates many entries of KernelToLinkedBitcode, so the lookup below + // almost always succeeds; in the worst case the extraction runs once per + // unique annotated kernel.
+ if (!KernelToLinkedBitcode.contains(Kernel)) + extractDeviceBitcode(KernelName, Kernel); + SmallVector RCsVec; getRuntimeConstantValues(KernelArgs, RCIndices, RCTypes, RCsVec); @@ -580,12 +589,8 @@ JitEngineDevice::compileAndRun( } } - if (!KernelToLinkedBitcode.contains(Kernel)) - extractDeviceBitcode(KernelName, Kernel); - - // We need to clone, The JitModule will be specialized later, and we need - // the one generic one in KernelToLinkedBitcode to be a generic version prior - // specialization. + // We clone the JitModule as we need a copy of a generic, non specialized one + // in the KernelToLinkedBitcode auto JitModule = llvm::CloneModule(*KernelToLinkedBitcode[Kernel]); specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices, diff --git a/lib/JitEngineDeviceHIP.cpp b/lib/JitEngineDeviceHIP.cpp index c618eddb..dc61b545 100644 --- a/lib/JitEngineDeviceHIP.cpp +++ b/lib/JitEngineDeviceHIP.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/TargetSelect.h" +#include "JitEngineDevice.hpp" #include "JitEngineDeviceHIP.hpp" #include "Utils.h" @@ -38,11 +39,14 @@ JitEngineDeviceHIP &JitEngineDeviceHIP::instance() { return Jit; } -std::unique_ptr -JitEngineDeviceHIP::extractDeviceBitcode(StringRef KernelName, void *Kernel) { +void JitEngineDeviceHIP::extractDeviceBitcode(StringRef KernelName, + void *Kernel) { constexpr char OFFLOAD_BUNDLER_MAGIC_STR[] = "__CLANG_OFFLOAD_BUNDLE__"; size_t Pos = 0; + if (KernelToLinkedBitcode.contains(Kernel)) + return; + if (!KernelToHandleMap.contains(Kernel)) FATAL_ERROR("Expected Kerne in map"); @@ -106,11 +110,9 @@ JitEngineDeviceHIP::extractDeviceBitcode(StringRef KernelName, void *Kernel) { ArrayRef DeviceBitcode; SmallVector> LinkedModules; - auto Ctx = std::make_unique(); - auto JitModule = std::make_unique("JitModule", *Ctx); - - auto extractModuleFromSection = [&DeviceElf, &Ctx](auto &Section, - StringRef SectionName) { + auto extractModuleFromSection = [&DeviceElf](auto 
&Section, + StringRef SectionName, + LLVMContext &Ctx) { ArrayRef BitcodeData; auto SectionContents = DeviceElf->getSectionContents(Section); if (SectionContents.takeError()) @@ -120,7 +122,7 @@ JitEngineDeviceHIP::extractDeviceBitcode(StringRef KernelName, void *Kernel) { BitcodeData.size()}; SMDiagnostic Err; - auto M = parseIR(MemoryBufferRef{Bitcode, SectionName}, Err, *Ctx); + auto M = parseIR(MemoryBufferRef{Bitcode, SectionName}, Err, Ctx); if (!M) FATAL_ERROR("unexpected"); @@ -135,29 +137,60 @@ JitEngineDeviceHIP::extractDeviceBitcode(StringRef KernelName, void *Kernel) { if (SectionName.takeError()) FATAL_ERROR("Error reading section name"); DBG(Logger::logs("proteus") << "SectionName " << *SectionName << "\n"); + const std::string BitCodePrefix(".jit.bitcode."); + if (!SectionName->starts_with(BitCodePrefix)) + continue; - if (!SectionName->starts_with(".jit.bitcode")) + // Skip cause we will handle later. + if (SectionName->equals(".jit.bitcode.lto")) continue; - auto M = extractModuleFromSection(Section, *SectionName); + auto SHA256 = SectionName->substr(BitCodePrefix.size()); + + auto M = extractModuleFromSection(Section, *SectionName, *ProteusCtx); + DenseSet KernelNames; + for (auto &Func : *M) { + if (Func.getCallingConv() == CallingConv::AMDGPU_KERNEL) + KernelNames.insert(Func.getName()); + } - if (SectionName->equals(".jit.bitcode.lto")) { - LinkedModules.clear(); - LinkedModules.push_back(std::move(M)); - break; - } else { - LinkedModules.push_back(std::move(M)); + for (auto &[Kernel, KernelInfo] : JITKernelInfoMap) { + auto it = KernelNames.find(KernelInfo.getName()); + if (it != KernelNames.end()) { + KernelInfo.setSHA256(SHA256); + } } + + LinkedModules.push_back(std::move(M)); } - linkJitModule(JitModule.get(), Ctx.get(), KernelName, LinkedModules); + for (auto Section : *Sections) { + auto SectionName = DeviceElf->getSectionName(Section); + if (SectionName.takeError()) + FATAL_ERROR("Error reading section name"); + + if 
(!SectionName->equals(".jit.bitcode.lto")) + continue; - std::string LinkedDeviceBitcode; - raw_string_ostream OS(LinkedDeviceBitcode); - WriteBitcodeToFile(*JitModule.get(), OS); - OS.flush(); + auto M = extractModuleFromSection(Section, *SectionName, *ProteusCtx); + LinkedModules.clear(); + LinkedModules.push_back(std::move(M)); + break; + } + + auto JitModule = + std::shared_ptr(linkJitModule(KernelName, LinkedModules)); + + for (const auto &KV : KernelToHandleMap) { + // All kernels included in this collection of modules will have an + // identical non specialized IR file. Map all Kernels, to this generic IR + // file + if (KV.second == Handle) { + KernelToLinkedBitcode.try_emplace(KV.first, JitModule); + } + } - return MemoryBuffer::getMemBufferCopy(StringRef(LinkedDeviceBitcode)); + return; } void JitEngineDeviceHIP::setLaunchBoundsForKernel(Module &M, Function &F, @@ -208,7 +241,8 @@ JitEngineDeviceHIP::codegenObject(Module &M, StringRef DeviceArch) { (void **)JITOptionsValues, &hip_link_state_ptr)); // NOTE: the following version of te code does not set options. 
- // hiprtcErrCheck(hiprtcLinkCreate(0, nullptr, nullptr, &hip_link_state_ptr)); + // hiprtcErrCheck(hiprtcLinkCreate(0, nullptr, nullptr, + // &hip_link_state_ptr)); hiprtcErrCheck(hiprtcLinkAddData( hip_link_state_ptr, HIPRTC_JIT_INPUT_LLVM_BITCODE, diff --git a/lib/JitEngineDeviceHIP.hpp b/lib/JitEngineDeviceHIP.hpp index 09c5a237..b6f2ac20 100644 --- a/lib/JitEngineDeviceHIP.hpp +++ b/lib/JitEngineDeviceHIP.hpp @@ -77,7 +77,12 @@ class JitEngineDeviceHIP : public JitEngineDevice { return "_ZNK17__HIP_CoordinatesI14__HIP_BlockIdxE3__ZcvjEv"; }; - static const char *getFatBinSectionName() { return ".hip_fatbin"; } + static bool HashSection(StringRef sectionName) { + static DenseSet Sections{".hip_fatbin", ".hipFatBinSegment", + ".hip_gpubin_handle"}; + return Sections.contains(sectionName); + ; + } void *resolveDeviceGlobalAddr(const void *Addr); @@ -86,8 +91,7 @@ class JitEngineDeviceHIP : public JitEngineDevice { void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim); - std::unique_ptr extractDeviceBitcode(StringRef KernelName, - void *Kernel); + void extractDeviceBitcode(StringRef KernelName, void *Kernel); std::unique_ptr codegenObject(Module &M, StringRef DeviceArch); diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 36ad65b1..42d1541a 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -58,7 +58,6 @@ #include #include #include -#include #include #include #include @@ -88,6 +87,7 @@ constexpr char const *RegisterFunctionName = "__hipRegisterFunction"; constexpr char const *LaunchFunctionName = "hipLaunchKernel"; constexpr char const *RegisterVarName = "__hipRegisterVar"; constexpr char const *RegisterFatBinaryName = "__hipRegisterFatBinary"; +constexpr char const *FatBinWrapperSymbolName = "__hip_fatbin_wrapper"; #elif ENABLE_CUDA constexpr char const *RegisterFunctionName = "__cudaRegisterFunction"; constexpr char const *LaunchFunctionName = "cudaLaunchKernel"; @@ -99,6 +99,7 @@ constexpr char const *RegisterFunctionName = 
nullptr; constexpr char const *LaunchFunctionName = nullptr; constexpr char const *RegisterVarName = nullptr; constexpr char const *RegisterFatBinaryName = nullptr; +constexpr char const *FatBinWrapperSymbolName = nullptr; #endif using namespace llvm; @@ -146,6 +147,8 @@ class ProteusJitPassImpl { // Host compilation // ================ + DEBUG(dump(M, "host", IsLTO ? "lto-before-proteus" : "before-proteus")); + instrumentRegisterLinkedBinary(M); instrumentRegisterFatBinary(M); instrumentRegisterFatBinaryEnd(M); @@ -178,7 +181,7 @@ class ProteusJitPassImpl { if (verifyModule(M, &errs())) FATAL_ERROR("Broken original module found, compilation aborted!"); - DEBUG(dump(M, "host", "after-proteus")); + DEBUG(dump(M, "host", IsLTO ? "lto-after-proteus" : "after-proteus")); return true; } @@ -197,11 +200,12 @@ class ProteusJitPassImpl { std::string ModuleIR; }; - void dump(Module M, StringRef device, StringRef phase) { + void dump(Module &M, StringRef device, StringRef phase) { std::filesystem::path ModulePath(M.getSourceFileName()); std::filesystem::path filename(M.getSourceFileName()); std::string rrBC( - Twine(filename.filename().string() + "." + device + "." + phase).str()); + Twine(filename.filename().string() + "." + device + "." + phase + ".bc") + .str()); std::error_code EC; raw_fd_ostream OutBC(rrBC, EC); if (EC) @@ -468,7 +472,19 @@ class ProteusJitPassImpl { new GlobalVariable(M, JitModule->getType(), /* isConstant */ true, GlobalValue::ExternalLinkage, JitModule, GVName); appendToUsed(M, {GV}); - GV->setSection(".jit.bitcode" + (IsLTO ? ".lto" : getUniqueModuleId(&M))); + + auto ModHash = [&M] { + std::string BitcodeStr; + llvm::raw_string_ostream BitcodeStream(BitcodeStr); + WriteBitcodeToFile(M, BitcodeStream); + BitcodeStream.flush(); + llvm::SHA256 SHA256Hasher; + SHA256Hasher.update(BitcodeStr); + auto Digest = SHA256Hasher.final(); + return llvm::toHex(Digest); + }(); + + GV->setSection(".jit.bitcode." + (IsLTO ? 
"lto" : ModHash)); DEBUG(Logger::logs("proteus-pass") << "Emit jit bitcode GV " << GVName << "\n"); } @@ -1070,7 +1086,6 @@ class ProteusJitPassImpl { auto SHA256Global = getOrCreateModuleSHAGlobal(M); Value *SHAValue = Builder.CreateBitCast(SHA256Global, PtrTy); - SHAValue->dump(); constexpr int StubOperand = 1; Builder.CreateCall( @@ -1081,9 +1096,11 @@ class ProteusJitPassImpl { } } - GlobalVariable *computeSHA256(Module &M, GlobalVariable *GV) { - assert(GV->hasInitializer() && - "Global Variable must have initializer to compute the SHA256"); + Value *computeSHA256(Module &M, GlobalVariable *GV) { + if (!GV->hasInitializer()) { + PointerType *OpaquePtr = PointerType::get(M.getContext(), 0); + return ConstantPointerNull::get(OpaquePtr); + } LLVMContext &Context = M.getContext(); auto *LLVMDeviceIR = dyn_cast(GV->getInitializer()); StringRef Data = LLVMDeviceIR->getRawDataValues(); @@ -1108,8 +1125,8 @@ class ProteusJitPassImpl { return HashGlobal; } - GlobalVariable *getOrCreateModuleSHAGlobal(Module &M) { - GlobalVariable *ProteusSHA256GV = M.getGlobalVariable("sha256_hash"); + Value *getOrCreateModuleSHAGlobal(Module &M) { + Value *ProteusSHA256GV = M.getGlobalVariable("sha256_hash"); if (ProteusSHA256GV != nullptr) return ProteusSHA256GV;