Skip to content

Commit

Permalink
Compute SHA256 of module
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Dec 18, 2024
1 parent ddfb29b commit 8764ee1
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 17 deletions.
12 changes: 9 additions & 3 deletions lib/CompilerInterfaceDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "CompilerInterfaceDevice.h"
#include "JitEngineDevice.hpp"
#include <llvm/ADT/StringExtras.h>
#include <llvm/ADT/StringRef.h>

using namespace proteus;

Expand Down Expand Up @@ -46,8 +48,12 @@ __jit_register_linked_binary(void *FatbinWrapper, const char *ModuleId) {
}

extern "C" __attribute((used)) void
__jit_register_function(void *Handle, void *Kernel, char *KernelName,
int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs) {
__jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel,
char *KernelName, int32_t *RCIndices, int32_t *RCTypes,
int32_t NumRCs) {
auto &Jit = JitDeviceImplT::instance();
Jit.registerFunction(Handle, Kernel, KernelName, RCIndices, RCTypes, NumRCs);
std::cout << "Got ModuleSHA256 "
<< llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n";
Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices,
RCTypes, NumRCs);
}
19 changes: 11 additions & 8 deletions lib/JitEngineDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,17 @@ namespace proteus {
using namespace llvm;

class JITKernelInfo {
StringRef SHA256;
char const *Name;
SmallVector<int32_t> RCTypes;
SmallVector<int32_t> RCIndices;
int32_t NumRCs;

public:
JITKernelInfo(char const *Name, int32_t *RCIndices, int32_t *RCTypes,
int32_t NumRCs)
: Name(Name), RCIndices{ArrayRef{RCIndices, static_cast<size_t>(NumRCs)}},
JITKernelInfo(const char *SHA256, char const *Name, int32_t *RCIndices,
int32_t *RCTypes, int32_t NumRCs)
: SHA256(SHA256, 32), Name(Name),
RCIndices{ArrayRef{RCIndices, static_cast<size_t>(NumRCs)}},
RCTypes{ArrayRef{RCTypes, static_cast<size_t>(NumRCs)}},
NumRCs(NumRCs) {}

Expand Down Expand Up @@ -108,8 +110,9 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {
void registerFatBinary(void *Handle, FatbinWrapper_t *FatbinWrapper,
const char *ModuleId);
void registerFatBinaryEnd();
void registerFunction(void *Handle, void *Kernel, char *KernelName,
int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs);
void registerFunction(void *Handle, const char *SHA256, void *Kernel,
char *KernelName, int32_t *RCIndices, int32_t *RCTypes,
int32_t NumRCs);

struct BinaryInfo {
FatbinWrapper_t *FatbinWrapper;
Expand Down Expand Up @@ -674,8 +677,8 @@ template <typename ImplT> void JitEngineDevice<ImplT>::registerFatBinaryEnd() {
}

template <typename ImplT>
void JitEngineDevice<ImplT>::registerFunction(void *Handle, void *Kernel,
char *KernelName,
void JitEngineDevice<ImplT>::registerFunction(void *Handle, const char *SHA256,
void *Kernel, char *KernelName,
int32_t *RCIndices,
int32_t *RCTypes,
int32_t NumRCs) {
Expand All @@ -690,7 +693,7 @@ void JitEngineDevice<ImplT>::registerFunction(void *Handle, void *Kernel,
ProteusAnnotatedKernels.insert(Kernel);

JITKernelInfoMap[Kernel] =
JITKernelInfo(KernelName, RCIndices, RCTypes, NumRCs);
JITKernelInfo(SHA256, KernelName, RCIndices, RCTypes, NumRCs);
}

template <typename ImplT>
Expand Down
75 changes: 69 additions & 6 deletions pass/ProteusPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
#include "llvm/Transforms/IPO/StripSymbols.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <filesystem>
#include <llvm/ADT/SmallPtrSet.h>
#include <llvm/ADT/StringExtras.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Constant.h>
Expand Down Expand Up @@ -91,6 +93,7 @@ constexpr char const *RegisterFunctionName = "__cudaRegisterFunction";
constexpr char const *LaunchFunctionName = "cudaLaunchKernel";
constexpr char const *RegisterVarName = "__cudaRegisterVar";
constexpr char const *RegisterFatBinaryName = "__cudaRegisterFatBinary";
constexpr char const *FatBinWrapperSymbolName = "__cuda_fatbin_wrapper";
#else
constexpr char const *RegisterFunctionName = nullptr;
constexpr char const *LaunchFunctionName = nullptr;
Expand Down Expand Up @@ -153,8 +156,6 @@ class ProteusJitPassImpl {
// deterministic across executions
auto proteusHashMDNode =
addSHA256AsMetadata(M, getModuleBitcodeSHA256(M));
std::cout << "Proteus hash is :\n";
proteusHashMDNode->dump();
emitJitModuleDevice(M, IsLTO);
// The AOT compilation does not need the SHA, so we delete it
M.eraseNamedMetadata(proteusHashMDNode);
Expand Down Expand Up @@ -198,6 +199,16 @@ class ProteusJitPassImpl {
if (verifyModule(M, &errs()))
FATAL_ERROR("Broken original module found, compilation aborted!");

std::filesystem::path ModulePath(M.getSourceFileName());
std::filesystem::path filename(M.getSourceFileName());
std::string rrBC(Twine(filename.filename().string(), ".host.bc").str());
std::error_code EC;
raw_fd_ostream OutBC(rrBC, EC);
if (EC)
throw std::runtime_error("Cannot open device code " + rrBC);
OutBC << M;
OutBC.close();

return true;
}

Expand Down Expand Up @@ -944,14 +955,15 @@ class ProteusJitPassImpl {
FunctionCallee getJitRegisterFunctionFn(Module &M) {
// The prototype is
// __jit_register_function(void *Handle,
// const char ModuleSHA256,
// void *Kernel,
// char const *KernelName,
// int32_t *RCIndices,
// int32_t *RCTypes,
// int32_t NumRCs)
FunctionType *JitRegisterFunctionFnTy =
FunctionType::get(VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty},
/* isVarArg=*/false);
FunctionType *JitRegisterFunctionFnTy = FunctionType::get(
VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, PtrTy, Int32Ty},
/* isVarArg=*/false);
FunctionCallee JitRegisterKernelFn = M.getOrInsertFunction(
"__jit_register_function", JitRegisterFunctionFnTy);

Expand Down Expand Up @@ -1062,15 +1074,66 @@ class ProteusJitPassImpl {

FunctionCallee JitRegisterFunction = getJitRegisterFunctionFn(M);

auto SHA256Global = getOrCreateModuleSHAGlobal(M);
Value *SHAValue = Builder.CreateBitCast(SHA256Global, PtrTy);
SHAValue->dump();

constexpr int StubOperand = 1;
Builder.CreateCall(
JitRegisterFunction,
{RegisterCB->getArgOperand(0), RegisterCB->getArgOperand(1),
{RegisterCB->getArgOperand(0), SHAValue, RegisterCB->getArgOperand(1),
RegisterCB->getArgOperand(2), RuntimeConstantsIndicesAlloca,
RuntimeConstantsTypesAlloca, NumRCsValue});
}
}

GlobalVariable *computeSHA256(Module &M, GlobalVariable *GV) {
assert(GV->hasInitializer() &&
"Global Variable must have initializer to compute the SHA256");
LLVMContext &Context = M.getContext();
auto *LLVMDeviceIR = dyn_cast<ConstantDataArray>(GV->getInitializer());
StringRef Data = LLVMDeviceIR->getRawDataValues();
SHA256 Hash;
Hash.update(Data);
auto HashValue = Hash.result();
std::cout << "Hash value is " << llvm::toHex(HashValue) << "\n";
std::vector<Constant *> HashBytes;
for (uint8_t Byte : HashValue) {
HashBytes.push_back(ConstantInt::get(Type::getInt8Ty(Context), Byte));
}

// Create a ConstantDataArray from the hash bytes
ArrayType *HashType =
ArrayType::get(Type::getInt8Ty(Context), HashBytes.size());
Constant *HashInitializer = ConstantArray::get(HashType, HashBytes);
auto *HashGlobal = new GlobalVariable(M, HashType, /* isConstant */ true,
GlobalValue::PrivateLinkage,
HashInitializer, "sha256_hash");
// Add a null terminator to the hash bytes

return HashGlobal;
}

GlobalVariable *getOrCreateModuleSHAGlobal(Module &M) {
GlobalVariable *ProteusSHA256GV = M.getGlobalVariable("sha256_hash");
if (ProteusSHA256GV != nullptr)
return ProteusSHA256GV;

GlobalVariable *FatbinWrapper =
M.getGlobalVariable(FatBinWrapperSymbolName, true);
assert(FatbinWrapper && "Expected existing fat bin wrapper");
ConstantStruct *C =
dyn_cast<ConstantStruct>(FatbinWrapper->getInitializer());
assert(C->getType()->getNumElements() == 4 &&
"Expected four fields in fatbin wrapper struct");
constexpr int FatbinField = 2;
auto *Fatbin = C->getAggregateElement(FatbinField);
GlobalVariable *FatbinGV = dyn_cast<GlobalVariable>(Fatbin);
assert(FatbinGV && "Expected global variable for the fatbin object");
ProteusSHA256GV = computeSHA256(M, FatbinGV);
return ProteusSHA256GV;
}

void findJitVariables(Module &M) {
DEBUG(Logger::logs("proteus-pass") << "finding jit variables"
<< "\n");
Expand Down

0 comments on commit 8764ee1

Please sign in to comment.