From 95d993a838863269dc1b90de3808c1e40ac6d5f2 Mon Sep 17 00:00:00 2001 From: Henrich Lauko Date: Fri, 24 Jan 2025 20:28:36 +0100 Subject: [PATCH] [MLIR] Fix import of calls with mismatched variadic types (#124286) Previously, an indirect call was incorrectly generated when `llvm::CallBase::getCalledFunction` returned null due to a type mismatch between the call and the function. This patch updates the code to use `llvm::CallBase::getCalledOperand` instead. --- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 77 +++++++++++-------- .../test/Target/LLVMIR/Import/instructions.ll | 25 ++++++ 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index f6826a2362bfd..40d86efe605ad 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands( if (!callInst->getType()->isVoidTy()) types.push_back(convertType(callInst->getType())); - if (!callInst->getCalledFunction()) { - if (!allowInlineAsm || - !isa(callInst->getCalledOperand())) { - FailureOr called = convertValue(callInst->getCalledOperand()); - if (failed(called)) - return failure(); - operands.push_back(*called); - } + bool isInlineAsm = callInst->isInlineAsm(); + if (isInlineAsm && !allowInlineAsm) + return failure(); + + // Cannot use isIndirectCall() here because we need to handle Constant callees + // that are not considered indirect calls by LLVM. However, in MLIR, they are + // treated as indirect calls to constant operands that need to be converted. + // Skip the callee operand if it's inline assembly, as it's handled separately + // in InlineAsmOp. + if (!isa(callInst->getCalledOperand()) && !isInlineAsm) { + FailureOr called = convertValue(callInst->getCalledOperand()); + if (failed(called)) + return failure(); + operands.push_back(*called); } + SmallVector args(callInst->args()); FailureOr> arguments = convertValues(args); if (failed(arguments)) @@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { return success(); } if (inst->getOpcode() == llvm::Instruction::Call) { - auto *callInst = cast(inst); + auto callInst = cast(inst); + llvm::Value *calledOperand = callInst->getCalledOperand(); SmallVector types; SmallVector operands; @@ -1601,15 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { /*allowInlineAsm=*/true))) return failure(); - auto funcTy = - dyn_cast(convertType(callInst->getFunctionType())); - if (!funcTy) - return failure(); - - if (auto asmI = dyn_cast(callInst->getCalledOperand())) { + if (auto asmI = dyn_cast(calledOperand)) { + Type resultTy = convertType(callInst->getType()); + if (!resultTy) + return failure(); auto callOp = builder.create( - loc, funcTy.getReturnType(), operands, - builder.getStringAttr(asmI->getAsmString()), + loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()), builder.getStringAttr(asmI->getConstraintString()), /*has_side_effects=*/true, /*is_align_stack=*/false, /*asm_dialect=*/nullptr, @@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { else mapNoResultOp(inst, callOp); } else { - CallOp callOp; + auto funcTy = dyn_cast([&]() -> Type { + // Retrieve the real function type. For direct calls, use the callee's + // function type, as it may differ from the operand type in the case of + // variadic functions. For indirect calls, use the call function type. + if (auto callee = dyn_cast(calledOperand)) + return convertType(callee->getFunctionType()); + return convertType(callInst->getFunctionType()); + }()); + + if (!funcTy) + return failure(); - if (llvm::Function *callee = callInst->getCalledFunction()) { - callOp = builder.create( - loc, funcTy, SymbolRefAttr::get(context, callee->getName()), - operands); - } else { - callOp = builder.create(loc, funcTy, operands); - } + auto callOp = [&]() -> CallOp { + if (auto callee = dyn_cast(calledOperand)) { + auto name = SymbolRefAttr::get(context, callee->getName()); + return builder.create(loc, funcTy, name, operands); + } + return builder.create(loc, funcTy, operands); + }(); + + // Handle function attributes. callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); callOp.setTailCallKind( convertTailCallKindFromLLVM(callInst->getTailCallKind())); setFastmathFlagsAttr(inst, callOp); - // Handle function attributes. - if (callInst->hasFnAttr(llvm::Attribute::Convergent)) - callOp.setConvergent(true); - if (callInst->hasFnAttr(llvm::Attribute::NoUnwind)) - callOp.setNoUnwind(true); - if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) - callOp.setWillReturn(true); + callOp.setConvergent(callInst->isConvergent()); + callOp.setNoUnwind(callInst->doesNotThrow()); + callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn)); llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); ModRefInfo othermem = convertModRefInfoFromLLVM( diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index 7377e2584110b..77052ab6e41f6 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -570,6 +570,31 @@ define void @varargs_call(i32 %0) { ; // ----- +; CHECK: @varargs(...) +declare void @varargs(...) + +; CHECK-LABEL: @varargs_call +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +define void @varargs_call(i32 %0) { + ; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func) : (i32) -> () + call void @varargs(i32 %0) + ret void +} + +; // ----- + +; CHECK: @varargs(...) +declare void @varargs(...) + +; CHECK-LABEL: @empty_varargs_call +define void @empty_varargs_call() { + ; CHECK: llvm.call @varargs() vararg(!llvm.func) : () -> () + call void @varargs() + ret void +} + +; // ----- + ; CHECK: llvm.func @f() declare void @f()