Skip to content

Commit

Permalink
Enable LICM for ops with read side effects in scf.for wrapped by a guard
Browse files Browse the repository at this point in the history
  • Loading branch information
ardaunal committed Dec 19, 2024
1 parent bf700c3 commit 0e596fd
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 18 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
Expand Down Expand Up @@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
InterfaceMethod<[{
    Wraps the loop in a trip-count check guard. On success, returns the
    guard operation together with the loop region from which operations
    are hoisted. The default implementation returns failure, i.e. the loop
    does not support being wrapped.
  }],
  /*retTy=*/"::llvm::FailureOr<::std::pair<::mlir::Operation *, ::mlir::Region *>>",
  /*methodName=*/"wrapInTripCountCheck",
  /*args=*/(ins),
  /*methodBody=*/"",
  /*defaultImplementation=*/"return ::mlir::failure();"
>,
// Counterpart of `wrapInTripCountCheck`: removes the trip-count check guard
// around the loop. Takes no arguments and returns a LogicalResult; the default
// implementation below always reports failure.
InterfaceMethod<[{
Unwraps the trip-count check.
}],
/*retTy=*/"::llvm::LogicalResult",
/*methodName=*/"unwrapTripCountCheck",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns true if the given operation implements `MemoryEffectOpInterface` and
/// has only read effects. Returns false for any operation that does not
/// implement `MemoryEffectOpInterface`.
bool hasOnlyReadEffect(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,19 @@ class Value;
/// }
/// ```
///
/// Users must supply three callbacks.
/// Users must supply five callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
/// moved of the given region, e.g. if it is side-effect free.
/// moved out of the given region, e.g. if it is side-effect free or has only
/// read side effects.
/// - `wrapInGuard` wraps the loop in a trip-count check guard.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
/// this check.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
Expand All @@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion);
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
function_ref<void(Operation *, Region *)> moveOutOfRegion,
function_ref<LogicalResult()> unwrapGuard);

/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
Expand Down
59 changes: 56 additions & 3 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

/// Wraps this `scf.for` in an `scf.if` that checks whether the loop executes
/// at least once (upper bound greater than lower bound).
///
/// The `then` branch holds the loop and yields its results; the `else` branch
/// yields the loop's init args. All prior uses of the loop results are
/// redirected to the results of the new `scf.if`.
///
/// Returns the created `scf.if` operation together with the loop's region.
FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
  IRRewriter rewriter(this->getContext());
  OpBuilder::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointAfter(this->getOperation());

  auto loc = this->getLoc();
  // The loop bounds are compared as signed integers, so the guard must use a
  // signed predicate as well. An unsigned comparison would wrongly report a
  // zero trip count for loops with a negative lower bound (e.g. lb = -1,
  // ub = 5) and the guard would yield the init args instead of running the
  // loop.
  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                               this->getUpperBound(),
                                               this->getLowerBound());
  scf::YieldOp yieldInThen;
  // Create the trip-count check.
  auto ifOp = rewriter.create<scf::IfOp>(
      loc, cmpIOp,
      [&](OpBuilder &builder, Location loc) {
        yieldInThen = builder.create<scf::YieldOp>(loc, this->getResults());
      },
      [&](OpBuilder &builder, Location loc) {
        // The loop does not run: its results equal its init args.
        builder.create<scf::YieldOp>(loc, this->getInitArgs());
      });

  // Redirect every user of the loop results to the guard's results, except
  // the yield in the `then` branch, which must keep forwarding the loop's own
  // results.
  for (auto [forOpResult, ifOpResult] :
       llvm::zip(this->getResults(), ifOp.getResults()))
    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
  // Move the scf.for into the then block.
  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
  return std::make_pair(ifOp.getOperation(), &this->getRegion());
}

/// Removes the trip-count check guard around this `scf.for`.
///
/// Expects the loop's parent operation to be an `scf.if` whose condition is
/// produced by an `arith.cmpi`. Returns failure, without modifying the IR, if
/// that is not the case. On success the loop is moved in front of the
/// `scf.if`, uses of the `scf.if` results are replaced by the loop results,
/// and the `scf.if` is erased. The compare op is erased as well when nothing
/// else uses it.
LogicalResult ForOp::unwrapTripCountCheck() {
  Operation *ifOp = (*this)->getParentRegion()->getParentOp();
  if (!isa_and_nonnull<scf::IfOp>(ifOp))
    return failure();

  // The condition may be a block argument, in which case there is no defining
  // op; treat that as "not a guard we created" instead of asserting.
  Operation *cmpOp = ifOp->getOperand(0).getDefiningOp();
  if (!isa_and_nonnull<arith::CmpIOp>(cmpOp))
    return failure();

  // All checks passed; only now start mutating the IR.
  IRRewriter rewriter(ifOp->getContext());
  OpBuilder::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPoint(ifOp);

  Operation *wrappedForOp = this->getOperation();
  rewriter.moveOpBefore(wrappedForOp, ifOp);

  for (auto [forOpResult, ifOpResult] :
       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);

  rewriter.eraseOp(ifOp);
  // Erasing an op that still has uses is invalid; keep the compare if some
  // other operation consumes its result.
  if (cmpOp->use_empty())
    rewriter.eraseOp(cmpOp);
  return success();
}

/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down Expand Up @@ -3397,9 +3451,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {

if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
<< "expected as many input types as operands " << "(expected "
<< operands.size() << " got " << functionType.getNumInputs() << ")";
}

// Resolve input operands.
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
return wouldOpBeTriviallyDeadImpl(op);
}

/// Reports whether `op` implements `MemoryEffectOpInterface` and that
/// interface says the op only has `MemoryEffects::Read` effects. Operations
/// without the interface yield false.
bool mlir::hasOnlyReadEffect(Operation *op) {
  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!effectInterface)
    return false;
  return effectInterface.onlyHasEffect<MemoryEffects::Read>();
}

bool mlir::isMemoryEffectFree(Operation *op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
Expand Down
96 changes: 85 additions & 11 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
}

/// Returns true if `op`, or any operation nested within it, has at least one
/// operand for which `condition` returns false. The walk stops at the first
/// such operand.
static bool dependsOnGuarded(Operation *op,
                             function_ref<bool(OpOperand &)> condition) {
  WalkResult status = op->walk([&](Operation *nested) -> WalkResult {
    for (OpOperand &use : nested->getOpOperands())
      if (!condition(use))
        return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return status.wasInterrupted();
}

/// Value-based convenience overload: returns true if `op`, or any operation
/// nested within it, uses a value for which `definedOutsideGuard` returns
/// false.
static bool dependsOnGuarded(Operation *op,
                             function_ref<bool(Value)> definedOutsideGuard) {
  // Adapt the value predicate to the operand-based overload.
  auto operandDefinedOutside = [&](OpOperand &use) {
    Value usedValue = use.get();
    return definedOutsideGuard(usedValue);
  };
  return dependsOnGuarded(op, operandDefinedOutside);
}

/// Returns true if every operation directly contained in the blocks of
/// `loop`'s regions is either memory-effect free or has only read effects.
/// Returns false as soon as one operation satisfies neither.
static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
  for (Region &bodyRegion : loop->getRegions()) {
    // Region::getOps visits the operations of every block in the region.
    for (Operation &bodyOp : bodyRegion.getOps()) {
      bool pureOrReadOnly =
          isMemoryEffectFree(&bodyOp) || hasOnlyReadEffect(&bodyOp);
      if (!pureOrReadOnly)
        return false;
    }
  }
  return true;
}

size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
function_ref<void(Operation *, Region *)> moveOutOfRegion,
function_ref<LogicalResult()> unwrapGuard) {
size_t numMoved = 0;

for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");

auto loopSideEffectFreeOrHasOnlyReadSideEffect =
loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());

size_t numMovedWithoutGuard = 0;

FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
Region *loopRegion = region;
auto isLoopWrapped = false;
if (succeeded(ifOpAndRegion)) {
loopRegion = ifOpAndRegion->second;
isLoopWrapped = true;
}

std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
for (Operation &op : region->getOps())
for (Operation &op : loopRegion->getOps())
worklist.push(&op);

auto definedOutside = [&](Value value) {
return isDefinedOutsideRegion(value, region);
return isDefinedOutsideRegion(value, loopRegion);
};

auto definedOutsideGuard = [&](Value value) {
return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
};

while (!worklist.empty()) {
Operation *op = worklist.front();
worklist.pop();
// Skip ops that have already been moved. Check if the op can be hoisted.
if (op->getParentRegion() != region)
if (op->getParentRegion() != loopRegion)
continue;

LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
if (!shouldMoveOutOfRegion(op, region) ||

if (!shouldMoveOutOfRegion(op, loopRegion) ||
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
// write side effects in the loop.
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;

LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
moveOutOfRegion(op, region);

auto moveWithoutGuard = isMemoryEffectFree(op) &&
!dependsOnGuarded(op, definedOutsideGuard) &&
isLoopWrapped;
numMovedWithoutGuard += moveWithoutGuard;

moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
: loopRegion);
++numMoved;

// Since the op has been moved, we need to check its users within the
// top-level of the loop body.
for (Operation *user : op->getUsers())
if (user->getParentRegion() == region)
if (user->getParentRegion() == loopRegion)
worklist.push(user);
}

// Unwrap the loop if it was wrapped but no ops were moved in the guard.
if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
auto tripCountCheckUnwrapped = unwrapGuard();
if (failed(tripCountCheckUnwrapped))
llvm_unreachable("Should not fail unwrapping trip-count check");
}
}

return numMoved;
Expand All @@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
[&](Value value, Region *region) {
return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
return isMemoryEffectFree(op) && isSpeculatable(op);
return isSpeculatable(op) &&
(isMemoryEffectFree(op) || hasOnlyReadEffect(op));
},
[&]() { return loopLike.wrapInTripCountCheck(); },
[&](Operation *op, Region *region) {
op->moveBefore(region->getParentOp());
},
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
[&]() { return loopLike.unwrapTripCountCheck(); });
}

namespace {
Expand Down
Loading

0 comments on commit 0e596fd

Please sign in to comment.