Skip to content

Commit

Permalink
Sync RuntimeConstantTy between pass and runtime
Browse files Browse the repository at this point in the history
- Use cmake to generate LLVM IR for the RuntimeConstant type
- Generate the RuntimeConstantTy in Proteuss pass using the generated LLVM IR
- Avoid errors between the pass-library interface  due to platform specific
  padding, alignment, and long double implementation
  • Loading branch information
ggeorgakoudis committed Dec 23, 2024
1 parent 54dbb1f commit 823da13
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lib/CompilerInterfaceTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct RuntimeConstant {
};
RuntimeConstantType Value;
int32_t Slot{-1};
} __attribute__((packed));
};

} // namespace proteus

Expand Down
35 changes: 35 additions & 0 deletions pass/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,40 @@
# This creates an LLVM IR module that contains the type definition for
# proteus::RuntimeConstant. It embeds the module in a header used by
# ProteusPass to match its definition of the RuntimeConstant type to the one
# expected on the target platform.
set(GEN_SOURCE_CODE "#include \"${CMAKE_SOURCE_DIR}/lib/CompilerInterfaceTypes.h\"\nstatic proteus::RuntimeConstant RC;")
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/cmake-gen.cpp "${GEN_SOURCE_CODE}")
set(OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cmake-out.ll)
set(HEADER_FILE ${CMAKE_CURRENT_BINARY_DIR}/GenCompilerInterfaceTypes.h)

add_custom_command(
OUTPUT ${OUTPUT_FILE}
COMMAND ${CMAKE_CXX_COMPILER} ${CMAKE_CURRENT_BINARY_DIR}/cmake-gen.cpp -emit-llvm -S -o ${OUTPUT_FILE}
DEPENDS ${PROJECT_SOURCE_DIR}/lib/CompilerInterfaceTypes.h
COMMENT "Generating CompilerInterfaceTypes LLVM IR module file"
VERBATIM
)

add_custom_command(
OUTPUT ${HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -DOUTPUT_FILE=${OUTPUT_FILE} -DHEADER_FILE=${HEADER_FILE} -P ${CMAKE_CURRENT_BINARY_DIR}/embed_file.cmake
DEPENDS ${OUTPUT_FILE}
COMMENT "Generating ProteusPass header with GenModule"
)

add_custom_target(GenerateRuntimeConstantTyHeader DEPENDS ${HEADER_FILE})

file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/embed_file.cmake [[
file(READ "${OUTPUT_FILE}" FILE_CONTENTS)
file(WRITE "${HEADER_FILE}" "static const char GenModule[] = R\"===(${FILE_CONTENTS})===\";")
]])

add_library(ProteusPass SHARED ProteusPass.cpp)

add_dependencies(ProteusPass GenerateRuntimeConstantTyHeader)
target_include_directories(ProteusPass
PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

target_include_directories(ProteusPass
SYSTEM PRIVATE ${LLVM_INCLUDE_DIRS})

Expand Down
154 changes: 154 additions & 0 deletions pass/GenRuntimeConstantTy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
//===-- GenRuntimeConstantTy.hpp -- Generate runtime constant type --===//
//
// Part of the Proteus Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//===----------------------------------------------------------------------===//

#ifndef PROTEUS_GEN_RUNTIME_CONSTANT_TY_HPP
#define PROTEUS_GEN_RUNTIME_CONSTANT_TY_HPP

#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "GenCompilerInterfaceTypes.h"

extern const char GenModule[];

namespace proteus {

using namespace llvm;

// Extract the RuntimeConstantTy using the LLVM IR of a cmake-compiled
// module for the target platform to ensure the type definition between the
// pass and the runtime match.
static Expected<StructType *> getRuntimeConstantTy(LLVMContext &TargetCtx) {
LLVMContext Context;
SMDiagnostic Err;
auto GenM = parseAssemblyString(GenModule, Err, Context);
if (!GenM) {
return createStringError(inconvertibleErrorCode(),
"Cannot parse generated module: " +
Err.getMessage());
}

auto MapTypeToTargetContext =
[&TargetCtx](Type *Ty,
auto &&MapTypeToTargetContext) -> Expected<Type *> {
if (Ty->isStructTy()) {
auto *StructTy = cast<StructType>(Ty);

// Handle unnamed literal struct.
if (StructTy->isLiteral()) {
SmallVector<Type *> ElementTypes;
for (Type *ElemTy : StructTy->elements()) {
auto ExpectedMappedElemTy =
MapTypeToTargetContext(ElemTy, MapTypeToTargetContext);
if (auto E = ExpectedMappedElemTy.takeError())
return createStringError(
inconvertibleErrorCode(),
"Failed to map element type in literal struct: " +
toString(std::move(E)));
ElementTypes.push_back(ExpectedMappedElemTy.get());
}
return StructType::get(TargetCtx, ElementTypes, StructTy->isPacked());
}

// Handle named struct.
StructType *ExistingType =
StructType::getTypeByName(TargetCtx, StructTy->getName());
if (ExistingType)
return ExistingType;

// Recursively populate elements.
SmallVector<Type *> ElementTypes;
for (Type *ElemTy : StructTy->elements()) {
auto ExpectedMappedElemTy =
MapTypeToTargetContext(ElemTy, MapTypeToTargetContext);
if (auto E = ExpectedMappedElemTy.takeError()) {
return createStringError(
inconvertibleErrorCode(),
"Failed to map element type in named struct: " +
toString(std::move(E)));
}
ElementTypes.push_back(ExpectedMappedElemTy.get());
}

StructType *NewStruct = StructType::create(
TargetCtx, ElementTypes, StructTy->getName(), StructTy->isPacked());
return NewStruct;
}

if (Ty->isArrayTy()) {
ArrayType *ArrayTy = cast<ArrayType>(Ty);
auto ExpectedElementType = MapTypeToTargetContext(
ArrayTy->getElementType(), MapTypeToTargetContext);
if (auto E = ExpectedElementType.takeError()) {
return createStringError(inconvertibleErrorCode(),
"Failed to map array element type: " +
toString(std::move(E)));
}

return ArrayType::get(ExpectedElementType.get(),
ArrayTy->getNumElements());
}

if (Ty->isPointerTy()) {
PointerType *PointerTy = cast<PointerType>(Ty);
return PointerType::get(TargetCtx, PointerTy->getAddressSpace());
}

if (Ty->isIntegerTy()) {
return IntegerType::get(TargetCtx, cast<IntegerType>(Ty)->getBitWidth());
}

if (Ty->isFloatingPointTy()) {
if (Ty->isHalfTy())
return Type::getHalfTy(TargetCtx);
if (Ty->isFloatTy())
return Type::getFloatTy(TargetCtx);
if (Ty->isDoubleTy())
return Type::getDoubleTy(TargetCtx);
if (Ty->isFP128Ty())
return Type::getFP128Ty(TargetCtx);
if (Ty->isX86_FP80Ty())
return Type::getX86_FP80Ty(TargetCtx);
if (Ty->isPPC_FP128Ty())
return Type::getPPC_FP128Ty(TargetCtx);
}

std::string TyStr;
raw_string_ostream OS{TyStr};
Ty->print(OS, true);
return createStringError(inconvertibleErrorCode(),
"Unsupported type: " + TyStr);
};

StructType *GenRuntimeConstantTy =
StructType::getTypeByName(Context, "struct.proteus::RuntimeConstant");
if (!GenRuntimeConstantTy)
return createStringError(inconvertibleErrorCode(),
"Expected non-null GenRuntimeConstantTy");

auto ExpectedRuntimeConstantTy =
MapTypeToTargetContext(GenRuntimeConstantTy, MapTypeToTargetContext);
if (auto E = ExpectedRuntimeConstantTy.takeError())
return createStringError(inconvertibleErrorCode(),
"Failed to map runtime constant type: " +
toString(std::move(E)));

return cast<StructType>(ExpectedRuntimeConstantTy.get());
}

} // namespace proteus

#endif
27 changes: 16 additions & 11 deletions pass/ProteusPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@

#include "../lib/CompilerInterfaceTypes.h"

#include "GenRuntimeConstantTy.hpp"

#define DEBUG_TYPE "jitpass"
#ifdef ENABLE_DEBUG
#define DEBUG(x) x
Expand Down Expand Up @@ -113,8 +115,13 @@ class ProteusJitPassImpl {
Int32Ty = Type::getInt32Ty(M.getContext());
Int64Ty = Type::getInt64Ty(M.getContext());
Int128Ty = Type::getInt128Ty(M.getContext());
RuntimeConstantTy =
StructType::create({Int128Ty, Int32Ty}, "struct.args", true);

auto ExpecedRuntimeConstantTy = getRuntimeConstantTy(M.getContext());
if (auto E = ExpecedRuntimeConstantTy.takeError())
FATAL_ERROR("Expected valid generated RuntimeConstantTy: " +
toString(std::move(E)));

RuntimeConstantTy = ExpecedRuntimeConstantTy.get();
}

bool run(Module &M, bool IsLTO) {
Expand Down Expand Up @@ -525,32 +532,30 @@ class ProteusJitPassImpl {
auto *StrIRGlobal = Builder.CreateGlobalString(JFI.ModuleIR);

// Create the runtime constants data structure passed to the jit entry.
Value *RuntimeConstantsIndicesAlloca = nullptr;
Value *RuntimeConstantsAlloca = nullptr;
if (JFI.ConstantArgs.size() > 0) {
RuntimeConstantsIndicesAlloca =
Builder.CreateAlloca(RuntimeConstantArrayTy);
RuntimeConstantsAlloca = Builder.CreateAlloca(RuntimeConstantArrayTy);
// Zero-initialize the alloca to avoid stack garbage for caching.
Builder.CreateStore(Constant::getNullValue(RuntimeConstantArrayTy),
RuntimeConstantsIndicesAlloca);
RuntimeConstantsAlloca);
for (int ArgI = 0; ArgI < JFI.ConstantArgs.size(); ++ArgI) {
auto *GEP = Builder.CreateInBoundsGEP(
RuntimeConstantArrayTy, RuntimeConstantsIndicesAlloca,
RuntimeConstantArrayTy, RuntimeConstantsAlloca,
{Builder.getInt32(0), Builder.getInt32(ArgI)});
int ArgNo = JFI.ConstantArgs[ArgI];
Builder.CreateStore(StubFn->getArg(ArgNo), GEP);
}
} else
RuntimeConstantsIndicesAlloca =
RuntimeConstantsAlloca =
Constant::getNullValue(RuntimeConstantArrayTy->getPointerTo());

assert(RuntimeConstantsIndicesAlloca &&
assert(RuntimeConstantsAlloca &&
"Expected non-null runtime constants alloca");

auto *JitFnPtr = Builder.CreateCall(
JitEntryFn,
{FnNameGlobal, StrIRGlobal, Builder.getInt32(JFI.ModuleIR.size()),
RuntimeConstantsIndicesAlloca,
Builder.getInt32(JFI.ConstantArgs.size())});
RuntimeConstantsAlloca, Builder.getInt32(JFI.ConstantArgs.size())});
SmallVector<Value *, 8> Args;
for (auto &Arg : StubFn->args())
Args.push_back(&Arg);
Expand Down

0 comments on commit 823da13

Please sign in to comment.