diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index def1061ae45e..d9345627c3fe 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -551,6 +551,18 @@ static cir::CIRCallOpInterface emitCallLikeOp( extraFnAttrs); } +static RValue getRValueThroughMemory(mlir::Location loc, + CIRGenBuilderTy &builder, + mlir::Value val, + Address addr) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointAfterValue(val); + builder.createStore(loc, val, addr); + builder.restoreInsertionPoint(ip); + auto load = builder.createLoad(loc, addr); + return RValue::get(load); +} + RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo, const CIRGenCallee &Callee, ReturnValueSlot ReturnValue, @@ -890,19 +902,14 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo, assert(Results.size() <= 1 && "multiple returns NYI"); assert(Results[0].getType() == RetCIRTy && "Bitcast support NYI"); - auto reg = builder.getBlock()->getParent(); - if (reg != theCall->getParentRegion()) { + auto region = builder.getBlock()->getParent(); + if (region != theCall->getParentRegion()) { Address DestPtr = ReturnValue.getValue(); if (!DestPtr.isValid()) DestPtr = CreateMemTemp(RetTy, callLoc, "tmp"); - auto ip = builder.saveInsertionPoint(); - builder.setInsertionPointAfter(theCall); - builder.createStore(callLoc, Results[0], DestPtr); - builder.restoreInsertionPoint(ip); - auto load = builder.createLoad(callLoc, DestPtr); - return RValue::get(load); + return getRValueThroughMemory(callLoc, builder, Results[0], DestPtr); } return RValue::get(Results[0]); diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 95b5b008df1b..c9bd0fa3d973 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -19,6 +19,8 @@ #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/Passes.h" +#include + using namespace mlir; using namespace cir; @@ -910,6 +912,42 @@ void populateFlattenCFGPatterns(RewritePatternSet &patterns) { patterns.getContext()); } +void removeTempAllocas(DominanceInfo& dom, FuncOp fun) { + + fun.walk([&](AllocaOp op) { + if (op.getName().str().find("tmp") == std::string::npos) + return; + + StoreOp store; + LoadOp load; + int total = 0; + + for (auto* u : op->getUsers()) { + total++; + if (auto ld = dyn_cast(u)) + load = ld; + if (auto st = dyn_cast(u)) + if (st.getAddr() == op.getResult()) + store = st; + } + + if (total == 2 && load && store && dom.dominates(store, load)) { + if (load->hasOneUse()) { + if (auto st = dyn_cast(*load->user_begin())) { + if (auto al = dyn_cast(st.getAddr().getDefiningOp())) { + llvm::SmallVector vals; + vals.push_back(al.getResult()); + op->replaceAllUsesWith(vals); + op->erase(); + } + } + } + } + + }); + +} + void FlattenCFGPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateFlattenCFGPatterns(patterns); @@ -924,6 +962,12 @@ void FlattenCFGPass::runOnOperation() { // Apply patterns. if (applyOpPatternsAndFold(ops, std::move(patterns)).failed()) signalPassFailure(); + + auto &dom = getAnalysis(); + + getOperation()->walk([&](FuncOp fun) { + removeTempAllocas(dom, fun); + }); } } // namespace