diff --git a/lib/CompilerInterfaceDevice.cpp b/lib/CompilerInterfaceDevice.cpp index 01f4da7c..58035eb2 100644 --- a/lib/CompilerInterfaceDevice.cpp +++ b/lib/CompilerInterfaceDevice.cpp @@ -10,6 +10,8 @@ #include "CompilerInterfaceDevice.h" #include "JitEngineDevice.hpp" +#include +#include using namespace proteus; @@ -46,8 +48,12 @@ __jit_register_linked_binary(void *FatbinWrapper, const char *ModuleId) { } extern "C" __attribute((used)) void -__jit_register_function(void *Handle, void *Kernel, char *KernelName, - int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { +__jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel, + char *KernelName, int32_t *RCIndices, int32_t *RCTypes, + int32_t NumRCs) { auto &Jit = JitDeviceImplT::instance(); - Jit.registerFunction(Handle, Kernel, KernelName, RCIndices, RCTypes, NumRCs); + std::cout << "Got ModuleSHA256 " + << llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n"; + Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices, + RCTypes, NumRCs); } diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index 7f99bfd9..4d59ce2a 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -58,15 +58,17 @@ namespace proteus { using namespace llvm; class JITKernelInfo { + StringRef SHA256; char const *Name; SmallVector RCTypes; SmallVector RCIndices; int32_t NumRCs; public: - JITKernelInfo(char const *Name, int32_t *RCIndices, int32_t *RCTypes, - int32_t NumRCs) - : Name(Name), RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, + JITKernelInfo(const char *SHA256, char const *Name, int32_t *RCIndices, + int32_t *RCTypes, int32_t NumRCs) + : SHA256(SHA256, 32), Name(Name), + RCIndices{ArrayRef{RCIndices, static_cast(NumRCs)}}, RCTypes{ArrayRef{RCTypes, static_cast(NumRCs)}}, NumRCs(NumRCs) {} @@ -108,8 +110,9 @@ template class JitEngineDevice : protected JitEngine { void registerFatBinary(void *Handle, FatbinWrapper_t *FatbinWrapper, const char *ModuleId); void registerFatBinaryEnd(); - void registerFunction(void *Handle, void *Kernel, char *KernelName, - int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs); + void registerFunction(void *Handle, const char *SHA256, void *Kernel, + char *KernelName, int32_t *RCIndices, int32_t *RCTypes, + int32_t NumRCs); struct BinaryInfo { FatbinWrapper_t *FatbinWrapper; @@ -674,8 +677,8 @@ template void JitEngineDevice::registerFatBinaryEnd() { } template -void JitEngineDevice::registerFunction(void *Handle, void *Kernel, - char *KernelName, +void JitEngineDevice::registerFunction(void *Handle, const char *SHA256, + void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) { @@ -690,7 +693,7 @@ void JitEngineDevice::registerFunction(void *Handle, void *Kernel, ProteusAnnotatedKernels.insert(Kernel); JITKernelInfoMap[Kernel] = - JITKernelInfo(KernelName, RCIndices, RCTypes, NumRCs); + JITKernelInfo(SHA256, KernelName, RCIndices, RCTypes, NumRCs); } template diff --git a/pass/ProteusPass.cpp b/pass/ProteusPass.cpp index 938331c6..d478269e 100644 --- a/pass/ProteusPass.cpp +++ b/pass/ProteusPass.cpp @@ -44,7 +44,9 @@ #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include #include +#include #include #include #include @@ -91,6 +93,7 @@ constexpr char const *RegisterFunctionName = "__cudaRegisterFunction"; constexpr char const *LaunchFunctionName = "cudaLaunchKernel"; constexpr char const *RegisterVarName = "__cudaRegisterVar"; constexpr char const *RegisterFatBinaryName = "__cudaRegisterFatBinary"; +constexpr char const *FatBinWrapperSymbolName = "__cuda_fatbin_wrapper"; #else constexpr char const *RegisterFunctionName = nullptr; constexpr char const *LaunchFunctionName = nullptr; @@ -153,8 +156,6 @@ class ProteusJitPassImpl { // deterministic across executions auto proteusHashMDNode = addSHA256AsMetadata(M, getModuleBitcodeSHA256(M)); - std::cout << "Proteus hash is :\n"; - proteusHashMDNode->dump(); emitJitModuleDevice(M, IsLTO); // The AOT compilation does not need the SHA, so we delete it M.eraseNamedMetadata(proteusHashMDNode); @@ -198,6 +199,16 @@ class ProteusJitPassImpl { if (verifyModule(M, &errs())) FATAL_ERROR("Broken original module found, compilation aborted!"); + std::filesystem::path ModulePath(M.getSourceFileName()); + std::filesystem::path filename(M.getSourceFileName()); + std::string rrBC(Twine(filename.filename().string(), ".host.bc").str()); + std::error_code EC; + raw_fd_ostream OutBC(rrBC, EC); + if (EC) + throw std::runtime_error("Cannot open device code " + rrBC); + OutBC << M; + OutBC.close(); + return true; } @@ -944,14 +955,15 @@ class ProteusJitPassImpl { FunctionCallee getJitRegisterFunctionFn(Module &M) { // The prototype is // __jit_register_function(void *Handle, + // const char ModuleSHA256, // void *Kernel, // char const *KernelName, // int32_t *RCIndices, // int32_t *RCTypes, // int32_t NumRCs) - FunctionType *JitRegisterFunctionFnTy = - FunctionType::get(VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty}, - /* isVarArg=*/false); + FunctionType *JitRegisterFunctionFnTy = FunctionType::get( + VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty}, + /* isVarArg=*/false); FunctionCallee JitRegisterKernelFn = M.getOrInsertFunction( "__jit_register_function", JitRegisterFunctionFnTy); @@ -1062,15 +1074,66 @@ class ProteusJitPassImpl { FunctionCallee JitRegisterFunction = getJitRegisterFunctionFn(M); + auto SHA256Global = getOrCreateModuleSHAGlobal(M); + Value *SHAValue = Builder.CreateBitCast(SHA256Global, PtrTy); + SHAValue->dump(); + constexpr int StubOperand = 1; Builder.CreateCall( JitRegisterFunction, - {RegisterCB->getArgOperand(0), RegisterCB->getArgOperand(1), + {RegisterCB->getArgOperand(0), SHAValue, RegisterCB->getArgOperand(1), RegisterCB->getArgOperand(2), RuntimeConstantsIndicesAlloca, RuntimeConstantsTypesAlloca, NumRCsValue}); } } + GlobalVariable *computeSHA256(Module &M, GlobalVariable *GV) { + assert(GV->hasInitializer() && + "Global Variable must have initializer to compute the SHA256"); + LLVMContext &Context = M.getContext(); + auto *LLVMDeviceIR = dyn_cast(GV->getInitializer()); + StringRef Data = LLVMDeviceIR->getRawDataValues(); + SHA256 Hash; + Hash.update(Data); + auto HashValue = Hash.result(); + std::cout << "Hash value is " << llvm::toHex(HashValue) << "\n"; + std::vector HashBytes; + for (uint8_t Byte : HashValue) { + HashBytes.push_back(ConstantInt::get(Type::getInt8Ty(Context), Byte)); + } + + // Create a ConstantDataArray from the hash bytes + ArrayType *HashType = + ArrayType::get(Type::getInt8Ty(Context), HashBytes.size()); + Constant *HashInitializer = ConstantArray::get(HashType, HashBytes); + auto *HashGlobal = new GlobalVariable(M, HashType, /* isConstant */ true, + GlobalValue::PrivateLinkage, + HashInitializer, "sha256_hash"); + // Add a null terminator to the hash bytes + + return HashGlobal; + } + + GlobalVariable *getOrCreateModuleSHAGlobal(Module &M) { + GlobalVariable *ProteusSHA256GV = M.getGlobalVariable("sha256_hash"); + if (ProteusSHA256GV != nullptr) + return ProteusSHA256GV; + + GlobalVariable *FatbinWrapper = + M.getGlobalVariable(FatBinWrapperSymbolName, true); + assert(FatbinWrapper && "Expected existing fat bin wrapper"); + ConstantStruct *C = + dyn_cast(FatbinWrapper->getInitializer()); + assert(C->getType()->getNumElements() == 4 && + "Expected four fields in fatbin wrapper struct"); + constexpr int FatbinField = 2; + auto *Fatbin = C->getAggregateElement(FatbinField); + GlobalVariable *FatbinGV = dyn_cast(Fatbin); + assert(FatbinGV && "Expected global variable for the fatbin object"); + ProteusSHA256GV = computeSHA256(M, FatbinGV); + return ProteusSHA256GV; + } + void findJitVariables(Module &M) { DEBUG(Logger::logs("proteus-pass") << "finding jit variables" << "\n");