From 85208a7e808eadae60576bcb8493b38b344afd3c Mon Sep 17 00:00:00 2001 From: koparasy Date: Wed, 18 Dec 2024 12:05:06 -0800 Subject: [PATCH] Correct cuda implementation --- lib/CompilerInterfaceDevice.cpp | 2 - lib/JitCache.hpp | 12 +-- lib/JitEngineDevice.hpp | 140 ++++++++++---------------------- lib/JitEngineDeviceCUDA.cpp | 36 +++----- lib/JitEngineDeviceCUDA.hpp | 13 +-- pass/ProteusPass.cpp | 64 ++++----------- 6 files changed, 81 insertions(+), 186 deletions(-) diff --git a/lib/CompilerInterfaceDevice.cpp b/lib/CompilerInterfaceDevice.cpp index 58035eb2..b93ea518 100644 --- a/lib/CompilerInterfaceDevice.cpp +++ b/lib/CompilerInterfaceDevice.cpp @@ -52,8 +52,6 @@ __jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { auto &Jit = JitDeviceImplT::instance(); - std::cout << "Got ModuleSHA256 " - << llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n"; Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices, RCTypes, NumRCs); } diff --git a/lib/JitCache.hpp b/lib/JitCache.hpp index 15a93b0f..075a2f6e 100644 --- a/lib/JitCache.hpp +++ b/lib/JitCache.hpp @@ -34,13 +34,13 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) { template class JitCache { public: - uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, const dim3 &GridDim, - const dim3 &BlockDim, const RuntimeConstant *RC, - int NumRuntimeConstants) const { + uint64_t hash(StringRef ModuleUniqueId, StringRef KernelHash, + StringRef KernelName, const dim3 &GridDim, const dim3 &BlockDim, + const RuntimeConstant *RC, int NumRuntimeConstants) const { ArrayRef Data(RC, NumRuntimeConstants); - auto HashValue = - hash_combine(ModuleUniqueId, FnName, GridDim.x, GridDim.y, GridDim.z, - BlockDim.x, BlockDim.y, BlockDim.z, Data); + auto HashValue = hash_combine(ModuleUniqueId, KernelHash, KernelName, + GridDim.x, GridDim.y, GridDim.z, BlockDim.x, + BlockDim.y, BlockDim.z, Data); return HashValue; } diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index 0f53458a..485a9937 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -78,6 +78,7 @@ class JITKernelInfo { const auto &getRCIndices() const { return RCIndices; } const auto &getRCTypes() const { return RCTypes; } const auto &getNumRCs() const { return NumRCs; } + const auto &getSHA256() const { return SHA256; } }; struct FatbinWrapper_t { @@ -125,11 +126,7 @@ template class JitEngineDevice : protected JitEngine { DenseMap HandleToBinaryInfo; DenseMap KernelToHandleMap; - SmallVector>>> - SHA256HashWithBitcodes; - - DenseMap> LinkedLLVMIRModules; - DenseMap KernelToBitcodeIndex; + DenseMap> KernelToLinkedBitcode; /* @Brief After proteus initialization contains all kernels annotathed with * proteus */ DenseSet ProteusAnnotatedKernels; @@ -298,15 +295,7 @@ template class JitEngineDevice : protected JitEngine { Image); } - std::unique_ptr - createLinkedModule(ArrayRef> LinkedModules, - StringRef KernelName) { - TIMESCOPE(__FUNCTION__) - return static_cast(*this).createLinkedModule(LinkedModules, - KernelName); - } - - int extractDeviceBitcode(StringRef KernelName, void *Kernel) { + void extractDeviceBitcode(StringRef KernelName, void *Kernel) { TIMESCOPE(__FUNCTION__) return static_cast(*this).extractDeviceBitcode(KernelName, Kernel); } @@ -344,17 +333,18 @@ template class JitEngineDevice : protected JitEngine { ObjectFile &elfObj = **objOrErr; - // Step 3: Iterate through sections and get their contents for (const SectionRef §ion : elfObj.sections()) { auto nameOrErr = section.getName(); if (!nameOrErr) FATAL_ERROR("Error getting section name: "); StringRef sectionName = nameOrErr.get(); - if (sectionName.compare(ImplT::getFatBinSectionName()) != 0) + if (!ImplT::HashSection(sectionName)) continue; - // Get the contents of the section + DBG(Logger::logs("proteus") + << "Hashing section " << sectionName.str() << "\n"); + auto contentsOrErr = section.getContents(); if (!contentsOrErr) { FATAL_ERROR("Error getting section contents: "); @@ -362,11 +352,7 @@ template class JitEngineDevice : protected JitEngine { } StringRef sectionContents = contentsOrErr.get(); - // Print section name and size - outs() << "Section: " << sectionName - << ", Size: " << sectionContents.size() << " bytes\n"; sha256.update(sectionContents); - break; } auto sha256Hash = sha256.final(); return llvm::toHex(sha256Hash); @@ -376,31 +362,25 @@ template class JitEngineDevice : protected JitEngine { JitEngineDevice() { ProteusCtx = std::make_unique(); ProteusDeviceBinHash = computeDeviceFatBinHash(); - std::cout << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"; + DBG(Logger::logs("proteus") + << "Device Bin Hash is " << ProteusDeviceBinHash << "\n"); } ~JitEngineDevice() { CodeCache.printStats(); StorageCache.printStats(); // Note: We manually clear or unique_ptr to Modules before the destructor // releases the ProteusCtx. - // - // Explicitly clear the LinkedLLVMIRModules - LinkedLLVMIRModules.clear(); - - // Explicitly clear SHA256HashWithBitcodes - for (auto &Entry : SHA256HashWithBitcodes) - Entry.second.clear(); - SHA256HashWithBitcodes.clear(); + + KernelToLinkedBitcode.clear(); } JitCache CodeCache; JitStorageCache StorageCache; std::string DeviceArch; std::unordered_map VarNameToDevPtr; - void linkJitModule(Module &M, StringRef KernelName, - ArrayRef> LinkedModules); - std::string - getCombinedModuleHash(ArrayRef> LinkedModules); + std::unique_ptr + linkJitModule(StringRef KernelName, + SmallVector> &LinkedModules); // All modules are associated with context, to guarantee correct lifetime. std::unique_ptr ProteusCtx; @@ -540,22 +520,17 @@ JitEngineDevice::compileAndRun( typename DeviceTraits::KernelFunction_t KernelFunc; - auto Index = KernelToBitcodeIndex.contains(Kernel) - ? KernelToBitcodeIndex[Kernel] - : extractDeviceBitcode(KernelName, Kernel); - - // I have already read the LLVM IR from the Binary. Pick the Static Hash - auto StaticHash = SHA256HashWithBitcodes[Index].first; - // TODO: This does not include the GridDims/BlockDims. We need to fix it. - auto PersistentHash = ProteusDeviceBinHash; - uint64_t DynamicHashValue = - CodeCache.hash(PersistentHash, KernelName, GridDim, BlockDim, + auto FatBinHash = ProteusDeviceBinHash; + auto KernelHash = JITKernelInfoMap[Kernel].getSHA256(); + uint64_t CombinedHash = + CodeCache.hash(FatBinHash, KernelHash, KernelName, GridDim, BlockDim, RCsVec.data(), NumRuntimeConstants); - KernelFunc = CodeCache.lookup(DynamicHashValue); - std::cout << " Function with name " << KernelName.str() << "at address " - << Kernel << " has PersistentHash " << PersistentHash - << " Static Hash:" << StaticHash - << " Dynamic Hash:" << DynamicHashValue << "\n"; + KernelFunc = CodeCache.lookup(CombinedHash); + + DBG(Logger::logs("proteus") + << " Function with name " << KernelName.str() << "at address " << Kernel + << " has PersistentHash " << FatBinHash << " Static Hash:" + << llvm::toHex(KernelHash) << " Dynamic Hash:" << CombinedHash << "\n"); // We found the kernel, execute if (KernelFunc) @@ -565,7 +540,7 @@ JitEngineDevice::compileAndRun( // NOTE: we don't need a suffix to differentiate kernels, each specialization // will be in its own module uniquely identify by HashValue. It exists only // for debugging purposes to verify that the jitted kernel executes. - std::string Suffix = mangleSuffix(DynamicHashValue); + std::string Suffix = mangleSuffix(CombinedHash); std::string KernelMangled = (KernelName + Suffix).str(); if (Config.ENV_PROTEUS_USE_STORED_CACHE) { @@ -582,9 +557,8 @@ JitEngineDevice::compileAndRun( bool HasDeviceGlobals = !VarNameToDevPtr.empty(); if (auto CacheBuf = (HasDeviceGlobals - ? StorageCache.lookupBitcode(DynamicHashValue, KernelMangled) - : StorageCache.lookupObject(DynamicHashValue, - KernelMangled))) { + ? StorageCache.lookupBitcode(CombinedHash, KernelMangled) + : StorageCache.lookupObject(CombinedHash, KernelMangled))) { std::unique_ptr ObjBuf; if (HasDeviceGlobals) { SMDiagnostic Err; @@ -598,7 +572,7 @@ JitEngineDevice::compileAndRun( auto KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -606,18 +580,13 @@ JitEngineDevice::compileAndRun( } } - if (!LinkedLLVMIRModules.contains(Kernel)) { - // if we get here, we have access to the LLVM-IR of the module, but we - // have never linked everything together and internalized the symbols. - LinkedLLVMIRModules.insert( - {Kernel, - createLinkedModule(SHA256HashWithBitcodes[Index].second, KernelName)}); - } + if (!KernelToLinkedBitcode.contains(Kernel)) + extractDeviceBitcode(KernelName, Kernel); // We need to clone, The JitModule will be specialized later, and we need - // the one stored under LinkedLLVMIRModules to be a generic version prior + // the one generic one in KernelToLinkedBitcode to be a generic version prior // specialization. - auto JitModule = llvm::CloneModule(*LinkedLLVMIRModules[Kernel]); + auto JitModule = llvm::CloneModule(*KernelToLinkedBitcode[Kernel]); specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices, RCsVec.data(), NumRuntimeConstants); @@ -629,18 +598,18 @@ JitEngineDevice::compileAndRun( SmallString<4096> ModuleBuffer; raw_svector_ostream ModuleBufferOS(ModuleBuffer); WriteBitcodeToFile(*JitModule, ModuleBufferOS); - StorageCache.storeBitcode(DynamicHashValue, ModuleBuffer); + StorageCache.storeBitcode(CombinedHash, ModuleBuffer); relinkGlobals(*JitModule, VarNameToDevPtr); auto ObjBuf = codegenObject(*JitModule, DeviceArch); if (Config.ENV_PROTEUS_USE_STORED_CACHE) - StorageCache.storeObject(DynamicHashValue, ObjBuf->getMemBufferRef()); + StorageCache.storeObject(CombinedHash, ObjBuf->getMemBufferRef()); KernelFunc = getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart()); - CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(), + CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(), NumRuntimeConstants); return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs, @@ -725,46 +694,19 @@ void JitEngineDevice::registerLinkedBinary( } template -std::string JitEngineDevice::getCombinedModuleHash( - ArrayRef> LinkedModules) { - SmallVector SHA256HashCodes; - for (auto &Mod : LinkedModules) { - NamedMDNode *ProteusSHANode = - Mod->getNamedMetadata("proteus.module.sha256"); - assert(ProteusSHANode != nullptr && - "Expected non-null proteus.module.sha256 metadata"); - assert(ProteusSHANode->getNumOperands() == 1 && - "Hash MD Node should have a single operand"); - auto MDHash = ProteusSHANode->getOperand(0); - MDString *sha256 = dyn_cast(MDHash->getOperand(0)); - if (!sha256) { - FATAL_ERROR("Could not read sha256 from module\n"); - } - SHA256HashCodes.push_back(sha256->getString().str()); - Mod->eraseNamedMetadata(ProteusSHANode); - } - - std::sort(SHA256HashCodes.begin(), SHA256HashCodes.end()); - std::string combinedHash; - for (auto hash : SHA256HashCodes) { - combinedHash += hash; - } - return combinedHash; -} - -template -void JitEngineDevice::linkJitModule( - Module &M, StringRef KernelName, - ArrayRef> LinkedModules) { +std::unique_ptr JitEngineDevice::linkJitModule( + StringRef KernelName, SmallVector> &LinkedModules) { if (LinkedModules.empty()) FATAL_ERROR("Expected jit module"); - Linker IRLinker(M); - for (auto &LinkedM : llvm::reverse(LinkedModules)) { + auto LinkedModule = std::make_unique("JitModule", *ProteusCtx); + Linker IRLinker(*LinkedModule); + for (auto &LinkedM : LinkedModules) { // Returns true if linking failed. if (IRLinker.linkInModule(std::move(LinkedM))) FATAL_ERROR("Linking failed"); } + return LinkedModule; } } // namespace proteus diff --git a/lib/JitEngineDeviceCUDA.cpp b/lib/JitEngineDeviceCUDA.cpp index f7c8d3d7..c54f6151 100644 --- a/lib/JitEngineDeviceCUDA.cpp +++ b/lib/JitEngineDeviceCUDA.cpp @@ -66,26 +66,18 @@ void JitEngineDeviceCUDA::extractLinkedBitcode( if (!M) FATAL_ERROR("unexpected"); - LinkedModules.emplace_back(std::move(M)); + LinkedModules.push_back(std::move(M)); } -std::unique_ptr JitEngineDeviceCUDA::createLinkedModule( - ArrayRef> LinkedModules, - StringRef KernelName) { - auto JitModule = std::make_unique("JitModule", *ProteusCtx); - linkJitModule(*JitModule, KernelName, LinkedModules); - return std::move(JitModule); -} - -int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, - void *Kernel) { +void JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, + void *Kernel) { CUmodule CUMod; CUdeviceptr DevPtr; size_t Bytes; - if (KernelToBitcodeIndex.contains(Kernel)) - return KernelToBitcodeIndex[Kernel]; + if (KernelToLinkedBitcode.contains(Kernel)) + return; SmallVector> LinkedModules; if (!KernelToHandleMap.contains(Kernel)) @@ -111,23 +103,18 @@ int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName, cuErrCheck(cuModuleUnload(CUMod)); - // Store the linked modules. For future accesses - int index = SHA256HashWithBitcodes.size(); - SHA256HashWithBitcodes.push_back( - std::make_pair(getCombinedModuleHash(LinkedModules), - SmallVector>())); - // Iterate and pop elements - for (auto it = LinkedModules.rbegin(); it != LinkedModules.rend(); ++it) { - SHA256HashWithBitcodes[index].second.push_back(std::move(*it)); - } + auto JitModule = + std::shared_ptr(linkJitModule(KernelName, LinkedModules)); for (const auto &KV : KernelToHandleMap) { + // All kernels included in this collection of modules will have an identical + // non specialized IR file. Map all Kernels, to this generic IR file if (KV.second == Handle) { - KernelToBitcodeIndex.try_emplace(KV.first, index); + KernelToLinkedBitcode.try_emplace(KV.first, JitModule); } } - return index; + return; } void JitEngineDeviceCUDA::setLaunchBoundsForKernel(Module &M, Function &F, @@ -242,7 +229,6 @@ JitEngineDeviceCUDA::codegenObject(Module &M, StringRef DeviceArch) { nvPTXCompilerErrCheck( nvPTXCompilerCreate(&PTXCompiler, PTXStr.size(), PTXStr.data())); std::string ArchOpt = ("--gpu-name=" + DeviceArch).str(); - std::string RDCOption = ""; if (!GlobalLinkedBinaries.empty()) RDCOption = "-c"; diff --git a/lib/JitEngineDeviceCUDA.hpp b/lib/JitEngineDeviceCUDA.hpp index 11cc83e9..b9ec7e95 100644 --- a/lib/JitEngineDeviceCUDA.hpp +++ b/lib/JitEngineDeviceCUDA.hpp @@ -15,6 +15,7 @@ #include "Utils.h" #include #include +#include namespace proteus { @@ -79,18 +80,18 @@ class JitEngineDeviceCUDA : public JitEngineDevice { return "llvm.nvvm.read.ptx.sreg.tid.z"; }; - static const char *getFatBinSectionName() { return ".nv_fatbin"; } + static bool HashSection(StringRef sectionName) { + static DenseSet Sections{".nv_fatbin", ".nvFatBinSegment", + "__nv_relfatbin", "__nv_module_id"}; + return Sections.contains(sectionName); + } void *resolveDeviceGlobalAddr(const void *Addr); void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize); - int extractDeviceBitcode(StringRef KernelName, void *Kernel); - - std::unique_ptr - createLinkedModule(ArrayRef> LinkedModules, - StringRef KernelName); + void extractDeviceBitcode(StringRef KernelName, void *Kernel); void codegenPTX(Module &M, StringRef DeviceArch, SmallVectorImpl &PTXStr); diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 9a2c791c..36ad65b1 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -122,22 +122,6 @@ class ProteusJitPassImpl { StructType::create({Int128Ty, Int32Ty}, "struct.args", true); } - // Function to attach the SHA-256 checksum as custom metadata to the Module - NamedMDNode *addSHA256AsMetadata(Module &M, - const std::string &SHA256Checksum) { - // Get or create a named metadata node - NamedMDNode *SHA256Metadata = - M.getOrInsertNamedMetadata("proteus.module.sha256"); - - // Create a metadata node with the SHA-256 hash as a string - MDNode *SHA256Node = MDNode::get( - M.getContext(), MDString::get(M.getContext(), SHA256Checksum)); - - // Add the metadata node as an operand to the named metadata - SHA256Metadata->addOperand(SHA256Node); - return SHA256Metadata; - } - bool run(Module &M, bool IsLTO) { parseAnnotations(M); @@ -152,14 +136,9 @@ class ProteusJitPassImpl { // For device compilation, just extract the module IR of device code // and return. if (isDeviceCompilation(M)) { - // Calling the SHA256 on top of the module, as `hash_code` is not - // deterministic across executions - auto proteusHashMDNode = - addSHA256AsMetadata(M, getModuleBitcodeSHA256(M)); + DEBUG(dump(M, "device", IsLTO ? "lto-before-proteus" : "before-proteus")); emitJitModuleDevice(M, IsLTO); - // The AOT compilation does not need the SHA, so we delete it - M.eraseNamedMetadata(proteusHashMDNode); - + DEBUG(dump(M, "device", IsLTO ? "lto-after-proteus" : "after-proteus")); return true; } @@ -199,15 +178,7 @@ class ProteusJitPassImpl { if (verifyModule(M, &errs())) FATAL_ERROR("Broken original module found, compilation aborted!"); - std::filesystem::path ModulePath(M.getSourceFileName()); - std::filesystem::path filename(M.getSourceFileName()); - std::string rrBC(Twine(filename.filename().string(), ".host.bc").str()); - std::error_code EC; - raw_fd_ostream OutBC(rrBC, EC); - if (EC) - throw std::runtime_error("Cannot open device code " + rrBC); - OutBC << M; - OutBC.close(); + DEBUG(dump(M, "host", "after-proteus")); return true; } @@ -226,6 +197,19 @@ class ProteusJitPassImpl { std::string ModuleIR; }; + void dump(Module M, StringRef device, StringRef phase) { + std::filesystem::path ModulePath(M.getSourceFileName()); + std::filesystem::path filename(M.getSourceFileName()); + std::string rrBC( + Twine(filename.filename().string() + "." + device + "." + phase).str()); + std::error_code EC; + raw_fd_ostream OutBC(rrBC, EC); + if (EC) + throw std::runtime_error("Cannot open device code " + rrBC); + OutBC << M; + OutBC.close(); + } + MapVector JitFunctionInfoMap; DenseMap StubToKernelMap; SmallPtrSet ModuleDeviceKernels; @@ -1192,22 +1176,6 @@ class ProteusJitPassImpl { } } } - - // Function to get the SHA-256 checksum of a Module's bitcode - std::string getModuleBitcodeSHA256(const Module &M) { - std::string BitcodeStr; - llvm::raw_string_ostream BitcodeStream(BitcodeStr); - - WriteBitcodeToFile(M, BitcodeStream); - BitcodeStream.flush(); - - llvm::SHA256 SHA256Hasher; - SHA256Hasher.update(BitcodeStr); // Feed the bitcode to the hasher - auto Digest = SHA256Hasher.final(); // Finalize the hash computation - - // Convert the SHA-256 result to a human-readable hexadecimal string - return llvm::toHex(Digest); - } }; // New PM implementation.