Skip to content

Commit

Permalink
[MLIR] Fix import of calls with mismatched variadic types (#124286)
Browse files Browse the repository at this point in the history
Previously, an indirect call was incorrectly generated when
`llvm::CallBase::getCalledFunction` returned null due to a type mismatch
between the call and the function. This patch updates the code to use
`llvm::CallBase::getCalledOperand` instead.
  • Loading branch information
xlauko authored Jan 24, 2025
1 parent 3b30f20 commit 95d993a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 32 deletions.
77 changes: 45 additions & 32 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));

if (!callInst->getCalledFunction()) {
if (!allowInlineAsm ||
!isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
if (failed(called))
return failure();
operands.push_back(*called);
}
bool isInlineAsm = callInst->isInlineAsm();
if (isInlineAsm && !allowInlineAsm)
return failure();

// Cannot use isIndirectCall() here because we need to handle Constant callees
// that are not considered indirect calls by LLVM. However, in MLIR, they are
// treated as indirect calls to constant operands that need to be converted.
// Skip the callee operand if it's inline assembly, as it's handled separately
// in InlineAsmOp.
if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
if (failed(called))
return failure();
operands.push_back(*called);
}

SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
if (failed(arguments))
Expand Down Expand Up @@ -1593,23 +1600,21 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}
if (inst->getOpcode() == llvm::Instruction::Call) {
auto *callInst = cast<llvm::CallInst>(inst);
auto callInst = cast<llvm::CallInst>(inst);
llvm::Value *calledOperand = callInst->getCalledOperand();

SmallVector<Type> types;
SmallVector<Value> operands;
if (failed(convertCallTypeAndOperands(callInst, types, operands,
/*allowInlineAsm=*/true)))
return failure();

auto funcTy =
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
if (!funcTy)
return failure();

if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
Type resultTy = convertType(callInst->getType());
if (!resultTy)
return failure();
auto callOp = builder.create<InlineAsmOp>(
loc, funcTy.getReturnType(), operands,
builder.getStringAttr(asmI->getAsmString()),
loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
Expand All @@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
else
mapNoResultOp(inst, callOp);
} else {
CallOp callOp;
auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
// Retrieve the real function type. For direct calls, use the callee's
// function type, as it may differ from the operand type in the case of
// variadic functions. For indirect calls, use the call function type.
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
return convertType(callee->getFunctionType());
return convertType(callInst->getFunctionType());
}());

if (!funcTy)
return failure();

if (llvm::Function *callee = callInst->getCalledFunction()) {
callOp = builder.create<CallOp>(
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
operands);
} else {
callOp = builder.create<CallOp>(loc, funcTy, operands);
}
auto callOp = [&]() -> CallOp {
if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
auto name = SymbolRefAttr::get(context, callee->getName());
return builder.create<CallOp>(loc, funcTy, name, operands);
}
return builder.create<CallOp>(loc, funcTy, operands);
}();

// Handle function attributes.
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);

// Handle function attributes.
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
callOp.setConvergent(true);
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
callOp.setNoUnwind(true);
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
callOp.setWillReturn(true);
callOp.setConvergent(callInst->isConvergent());
callOp.setNoUnwind(callInst->doesNotThrow());
callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));

llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,31 @@ define void @varargs_call(i32 %0) {

; // -----

; CHECK: @varargs(...)
declare void @varargs(...)

; CHECK-LABEL: @varargs_call
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
define void @varargs_call(i32 %0) {
; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
call void @varargs(i32 %0)
ret void
}

; // -----

; CHECK: @varargs(...)
declare void @varargs(...)

; CHECK-LABEL: @empty_varargs_call
define void @empty_varargs_call() {
; CHECK: llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
call void @varargs()
ret void
}

; // -----

; CHECK: llvm.func @f()
declare void @f()

Expand Down

0 comments on commit 95d993a

Please sign in to comment.