Skip to content

Commit

Permalink
Correct cuda implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Dec 18, 2024
1 parent 1b5660c commit 85208a7
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 186 deletions.
2 changes: 0 additions & 2 deletions lib/CompilerInterfaceDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ __jit_register_function(void *Handle, const char *ModuleSHA256, void *Kernel,
char *KernelName, int32_t *RCIndices, int32_t *RCTypes,
int32_t NumRCs) {
auto &Jit = JitDeviceImplT::instance();
std::cout << "Got ModuleSHA256 "
<< llvm::toHex(StringRef(ModuleSHA256, 32).str()) << "\n";
Jit.registerFunction(Handle, ModuleSHA256, Kernel, KernelName, RCIndices,
RCTypes, NumRCs);
}
12 changes: 6 additions & 6 deletions lib/JitCache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) {

template <typename Function_t> class JitCache {
public:
uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, const dim3 &GridDim,
const dim3 &BlockDim, const RuntimeConstant *RC,
int NumRuntimeConstants) const {
uint64_t hash(StringRef ModuleUniqueId, StringRef KernelHash,
StringRef KernelName, const dim3 &GridDim, const dim3 &BlockDim,
const RuntimeConstant *RC, int NumRuntimeConstants) const {
ArrayRef<RuntimeConstant> Data(RC, NumRuntimeConstants);
auto HashValue =
hash_combine(ModuleUniqueId, FnName, GridDim.x, GridDim.y, GridDim.z,
BlockDim.x, BlockDim.y, BlockDim.z, Data);
auto HashValue = hash_combine(ModuleUniqueId, KernelHash, KernelName,
GridDim.x, GridDim.y, GridDim.z, BlockDim.x,
BlockDim.y, BlockDim.z, Data);
return HashValue;
}

Expand Down
140 changes: 41 additions & 99 deletions lib/JitEngineDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class JITKernelInfo {
const auto &getRCIndices() const { return RCIndices; }
const auto &getRCTypes() const { return RCTypes; }
const auto &getNumRCs() const { return NumRCs; }
const auto &getSHA256() const { return SHA256; }
};

struct FatbinWrapper_t {
Expand Down Expand Up @@ -125,11 +126,7 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {
DenseMap<void *, BinaryInfo> HandleToBinaryInfo;
DenseMap<void *, void *> KernelToHandleMap;

SmallVector<std::pair<std::string, SmallVector<std::unique_ptr<Module>>>>
SHA256HashWithBitcodes;

DenseMap<void *, std::unique_ptr<llvm::Module>> LinkedLLVMIRModules;
DenseMap<void *, int> KernelToBitcodeIndex;
DenseMap<void *, std::shared_ptr<llvm::Module>> KernelToLinkedBitcode;
/* @Brief After proteus initialization contains all kernels annotated with
* proteus */
DenseSet<void *> ProteusAnnotatedKernels;
Expand Down Expand Up @@ -298,15 +295,7 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {
Image);
}

std::unique_ptr<llvm::Module>
createLinkedModule(ArrayRef<std::unique_ptr<Module>> LinkedModules,
StringRef KernelName) {
TIMESCOPE(__FUNCTION__)
return static_cast<ImplT &>(*this).createLinkedModule(LinkedModules,
KernelName);
}

int extractDeviceBitcode(StringRef KernelName, void *Kernel) {
void extractDeviceBitcode(StringRef KernelName, void *Kernel) {
TIMESCOPE(__FUNCTION__)
return static_cast<ImplT &>(*this).extractDeviceBitcode(KernelName, Kernel);
}
Expand Down Expand Up @@ -344,29 +333,26 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {

ObjectFile &elfObj = **objOrErr;

// Step 3: Iterate through sections and get their contents
for (const SectionRef &section : elfObj.sections()) {
auto nameOrErr = section.getName();
if (!nameOrErr)
FATAL_ERROR("Error getting section name: ");

StringRef sectionName = nameOrErr.get();
if (sectionName.compare(ImplT::getFatBinSectionName()) != 0)
if (!ImplT::HashSection(sectionName))
continue;

// Get the contents of the section
DBG(Logger::logs("proteus")
<< "Hashing section " << sectionName.str() << "\n");

auto contentsOrErr = section.getContents();
if (!contentsOrErr) {
FATAL_ERROR("Error getting section contents: ");
continue;
}
StringRef sectionContents = contentsOrErr.get();

// Print section name and size
outs() << "Section: " << sectionName
<< ", Size: " << sectionContents.size() << " bytes\n";
sha256.update(sectionContents);
break;
}
auto sha256Hash = sha256.final();
return llvm::toHex(sha256Hash);
Expand All @@ -376,31 +362,25 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {
JitEngineDevice() {
ProteusCtx = std::make_unique<LLVMContext>();
ProteusDeviceBinHash = computeDeviceFatBinHash();
std::cout << "Device Bin Hash is " << ProteusDeviceBinHash << "\n";
DBG(Logger::logs("proteus")
<< "Device Bin Hash is " << ProteusDeviceBinHash << "\n");
}
~JitEngineDevice() {
CodeCache.printStats();
StorageCache.printStats();
// Note: We manually clear our unique_ptr to Modules before the destructor
// releases the ProteusCtx.
//
// Explicitly clear the LinkedLLVMIRModules
LinkedLLVMIRModules.clear();

// Explicitly clear SHA256HashWithBitcodes
for (auto &Entry : SHA256HashWithBitcodes)
Entry.second.clear();
SHA256HashWithBitcodes.clear();

KernelToLinkedBitcode.clear();
}

JitCache<KernelFunction_t> CodeCache;
JitStorageCache<KernelFunction_t> StorageCache;
std::string DeviceArch;
std::unordered_map<std::string, const void *> VarNameToDevPtr;
void linkJitModule(Module &M, StringRef KernelName,
ArrayRef<std::unique_ptr<Module>> LinkedModules);
std::string
getCombinedModuleHash(ArrayRef<std::unique_ptr<Module>> LinkedModules);
std::unique_ptr<Module>
linkJitModule(StringRef KernelName,
SmallVector<std::unique_ptr<Module>> &LinkedModules);

// All modules are associated with context, to guarantee correct lifetime.
std::unique_ptr<LLVMContext> ProteusCtx;
Expand Down Expand Up @@ -540,22 +520,17 @@ JitEngineDevice<ImplT>::compileAndRun(

typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc;

auto Index = KernelToBitcodeIndex.contains(Kernel)
? KernelToBitcodeIndex[Kernel]
: extractDeviceBitcode(KernelName, Kernel);

// I have already read the LLVM IR from the Binary. Pick the Static Hash
auto StaticHash = SHA256HashWithBitcodes[Index].first;
// TODO: This does not include the GridDims/BlockDims. We need to fix it.
auto PersistentHash = ProteusDeviceBinHash;
uint64_t DynamicHashValue =
CodeCache.hash(PersistentHash, KernelName, GridDim, BlockDim,
auto FatBinHash = ProteusDeviceBinHash;
auto KernelHash = JITKernelInfoMap[Kernel].getSHA256();
uint64_t CombinedHash =
CodeCache.hash(FatBinHash, KernelHash, KernelName, GridDim, BlockDim,
RCsVec.data(), NumRuntimeConstants);
KernelFunc = CodeCache.lookup(DynamicHashValue);
std::cout << " Function with name " << KernelName.str() << "at address "
<< Kernel << " has PersistentHash " << PersistentHash
<< " Static Hash:" << StaticHash
<< " Dynamic Hash:" << DynamicHashValue << "\n";
KernelFunc = CodeCache.lookup(CombinedHash);

DBG(Logger::logs("proteus")
<< " Function with name " << KernelName.str() << "at address " << Kernel
<< " has PersistentHash " << FatBinHash << " Static Hash:"
<< llvm::toHex(KernelHash) << " Dynamic Hash:" << CombinedHash << "\n");

// We found the kernel, execute
if (KernelFunc)
Expand All @@ -565,7 +540,7 @@ JitEngineDevice<ImplT>::compileAndRun(
// NOTE: we don't need a suffix to differentiate kernels, each specialization
// will be in its own module uniquely identified by HashValue. It exists only
// for debugging purposes to verify that the jitted kernel executes.
std::string Suffix = mangleSuffix(DynamicHashValue);
std::string Suffix = mangleSuffix(CombinedHash);
std::string KernelMangled = (KernelName + Suffix).str();

if (Config.ENV_PROTEUS_USE_STORED_CACHE) {
Expand All @@ -582,9 +557,8 @@ JitEngineDevice<ImplT>::compileAndRun(
bool HasDeviceGlobals = !VarNameToDevPtr.empty();
if (auto CacheBuf =
(HasDeviceGlobals
? StorageCache.lookupBitcode(DynamicHashValue, KernelMangled)
: StorageCache.lookupObject(DynamicHashValue,
KernelMangled))) {
? StorageCache.lookupBitcode(CombinedHash, KernelMangled)
: StorageCache.lookupObject(CombinedHash, KernelMangled))) {
std::unique_ptr<MemoryBuffer> ObjBuf;
if (HasDeviceGlobals) {
SMDiagnostic Err;
Expand All @@ -598,26 +572,21 @@ JitEngineDevice<ImplT>::compileAndRun(
auto KernelFunc =
getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart());

CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(),
CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(),
NumRuntimeConstants);

return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
ShmemSize, Stream);
}
}

if (!LinkedLLVMIRModules.contains(Kernel)) {
// if we get here, we have access to the LLVM-IR of the module, but we
// have never linked everything together and internalized the symbols.
LinkedLLVMIRModules.insert(
{Kernel,
createLinkedModule(SHA256HashWithBitcodes[Index].second, KernelName)});
}
if (!KernelToLinkedBitcode.contains(Kernel))
extractDeviceBitcode(KernelName, Kernel);

// We need to clone. The JitModule will be specialized later, and we need
// the one stored under LinkedLLVMIRModules to be a generic version prior
// the generic one in KernelToLinkedBitcode to be a generic version prior
// to specialization.
auto JitModule = llvm::CloneModule(*LinkedLLVMIRModules[Kernel]);
auto JitModule = llvm::CloneModule(*KernelToLinkedBitcode[Kernel]);

specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices,
RCsVec.data(), NumRuntimeConstants);
Expand All @@ -629,18 +598,18 @@ JitEngineDevice<ImplT>::compileAndRun(
SmallString<4096> ModuleBuffer;
raw_svector_ostream ModuleBufferOS(ModuleBuffer);
WriteBitcodeToFile(*JitModule, ModuleBufferOS);
StorageCache.storeBitcode(DynamicHashValue, ModuleBuffer);
StorageCache.storeBitcode(CombinedHash, ModuleBuffer);

relinkGlobals(*JitModule, VarNameToDevPtr);

auto ObjBuf = codegenObject(*JitModule, DeviceArch);
if (Config.ENV_PROTEUS_USE_STORED_CACHE)
StorageCache.storeObject(DynamicHashValue, ObjBuf->getMemBufferRef());
StorageCache.storeObject(CombinedHash, ObjBuf->getMemBufferRef());

KernelFunc =
getKernelFunctionFromImage(KernelMangled, ObjBuf->getBufferStart());

CodeCache.insert(DynamicHashValue, KernelFunc, KernelName, RCsVec.data(),
CodeCache.insert(CombinedHash, KernelFunc, KernelName, RCsVec.data(),
NumRuntimeConstants);

return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
Expand Down Expand Up @@ -725,46 +694,19 @@ void JitEngineDevice<ImplT>::registerLinkedBinary(
}

template <typename ImplT>
std::string JitEngineDevice<ImplT>::getCombinedModuleHash(
ArrayRef<std::unique_ptr<Module>> LinkedModules) {
SmallVector<std::string> SHA256HashCodes;
for (auto &Mod : LinkedModules) {
NamedMDNode *ProteusSHANode =
Mod->getNamedMetadata("proteus.module.sha256");
assert(ProteusSHANode != nullptr &&
"Expected non-null proteus.module.sha256 metadata");
assert(ProteusSHANode->getNumOperands() == 1 &&
"Hash MD Node should have a single operand");
auto MDHash = ProteusSHANode->getOperand(0);
MDString *sha256 = dyn_cast<MDString>(MDHash->getOperand(0));
if (!sha256) {
FATAL_ERROR("Could not read sha256 from module\n");
}
SHA256HashCodes.push_back(sha256->getString().str());
Mod->eraseNamedMetadata(ProteusSHANode);
}

std::sort(SHA256HashCodes.begin(), SHA256HashCodes.end());
std::string combinedHash;
for (auto hash : SHA256HashCodes) {
combinedHash += hash;
}
return combinedHash;
}

template <typename ImplT>
void JitEngineDevice<ImplT>::linkJitModule(
Module &M, StringRef KernelName,
ArrayRef<std::unique_ptr<Module>> LinkedModules) {
std::unique_ptr<Module> JitEngineDevice<ImplT>::linkJitModule(
StringRef KernelName, SmallVector<std::unique_ptr<Module>> &LinkedModules) {
if (LinkedModules.empty())
FATAL_ERROR("Expected jit module");

Linker IRLinker(M);
for (auto &LinkedM : llvm::reverse(LinkedModules)) {
auto LinkedModule = std::make_unique<llvm::Module>("JitModule", *ProteusCtx);
Linker IRLinker(*LinkedModule);
for (auto &LinkedM : LinkedModules) {
// Returns true if linking failed.
if (IRLinker.linkInModule(std::move(LinkedM)))
FATAL_ERROR("Linking failed");
}
return LinkedModule;
}

} // namespace proteus
Expand Down
36 changes: 11 additions & 25 deletions lib/JitEngineDeviceCUDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,18 @@ void JitEngineDeviceCUDA::extractLinkedBitcode(
if (!M)
FATAL_ERROR("unexpected");

LinkedModules.emplace_back(std::move(M));
LinkedModules.push_back(std::move(M));
}

std::unique_ptr<llvm::Module> JitEngineDeviceCUDA::createLinkedModule(
ArrayRef<std::unique_ptr<llvm::Module>> LinkedModules,
StringRef KernelName) {
auto JitModule = std::make_unique<llvm::Module>("JitModule", *ProteusCtx);
linkJitModule(*JitModule, KernelName, LinkedModules);
return std::move(JitModule);
}

int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName,
void *Kernel) {
void JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName,
void *Kernel) {
CUmodule CUMod;
CUdeviceptr DevPtr;

size_t Bytes;

if (KernelToBitcodeIndex.contains(Kernel))
return KernelToBitcodeIndex[Kernel];
if (KernelToLinkedBitcode.contains(Kernel))
return;

SmallVector<std::unique_ptr<Module>> LinkedModules;
if (!KernelToHandleMap.contains(Kernel))
Expand All @@ -111,23 +103,18 @@ int JitEngineDeviceCUDA::extractDeviceBitcode(StringRef KernelName,

cuErrCheck(cuModuleUnload(CUMod));

// Store the linked modules. For future accesses
int index = SHA256HashWithBitcodes.size();
SHA256HashWithBitcodes.push_back(
std::make_pair(getCombinedModuleHash(LinkedModules),
SmallVector<std::unique_ptr<Module>>()));
// Iterate and pop elements
for (auto it = LinkedModules.rbegin(); it != LinkedModules.rend(); ++it) {
SHA256HashWithBitcodes[index].second.push_back(std::move(*it));
}
auto JitModule =
std::shared_ptr<Module>(linkJitModule(KernelName, LinkedModules));

for (const auto &KV : KernelToHandleMap) {
// All kernels included in this collection of modules will have an identical
// non-specialized IR file. Map all kernels to this generic IR file.
if (KV.second == Handle) {
KernelToBitcodeIndex.try_emplace(KV.first, index);
KernelToLinkedBitcode.try_emplace(KV.first, JitModule);
}
}

return index;
return;
}

void JitEngineDeviceCUDA::setLaunchBoundsForKernel(Module &M, Function &F,
Expand Down Expand Up @@ -242,7 +229,6 @@ JitEngineDeviceCUDA::codegenObject(Module &M, StringRef DeviceArch) {
nvPTXCompilerErrCheck(
nvPTXCompilerCreate(&PTXCompiler, PTXStr.size(), PTXStr.data()));
std::string ArchOpt = ("--gpu-name=" + DeviceArch).str();

std::string RDCOption = "";
if (!GlobalLinkedBinaries.empty())
RDCOption = "-c";
Expand Down
Loading

0 comments on commit 85208a7

Please sign in to comment.