[mlir] Enable LICM for ops with only read side effects in scf.for #120302
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Author: Arda Unal (ardaunal)

Changes: Enable ops with only read side effects in scf.for to be hoisted, wrapped in an scf.if guard that checks the trip count. This patch takes a step toward a less conservative LICM in MLIR, as discussed in the following discourse thread:

This patch in particular does the following:
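As a sketch of the intended transformation (pieced together from the test expectations in the patch; value names are illustrative, not taken verbatim from it), a loop-invariant read-only op is hoisted under an scf.if that only fires when the loop would execute at least once:

```mlir
// Before: %r is loop-invariant but has a read side effect, so plain LICM
// would not hoist it.
%sum = scf.for %i = %lb to %ub step %step iter_args(%acc = %c0) -> (i32) {
  %r = "test.speculatable_op_with_memread"(%t, %idx) : (tensor<64xi32>, index) -> i32
  %next = arith.addi %acc, %r : i32
  scf.yield %next : i32
}

// After: the read only executes when the trip count is positive.
%diff = arith.subi %ub, %lb : index
%tc = arith.ceildivsi %diff, %step : index
%c0_idx = arith.constant 0 : index
%pos = arith.cmpi sgt, %tc, %c0_idx : index
%sum = scf.if %pos -> (i32) {
  %r = "test.speculatable_op_with_memread"(%t, %idx) : (tensor<64xi32>, index) -> i32
  %inner = scf.for %i = %lb to %ub step %step iter_args(%acc = %c0) -> (i32) {
    %next = arith.addi %acc, %r : i32
    scf.yield %next : i32
  }
  scf.yield %inner : i32
} else {
  scf.yield %c0 : i32
}
```

Pure (memory-effect-free) ops that do not depend on guarded values are hoisted past the guard entirely, as the CHECK lines in the added tests show.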
Patch is 22.40 KiB, truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/120302.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
+ "wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
+ ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
+ InterfaceMethod<[{
+ Wraps the loop into a trip-count check.
+ }],
+ /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+ /*methodName=*/"wrapInTripCountCheck",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return ::mlir::failure();"
+ >,
+ InterfaceMethod<[{
+ Unwraps the trip-count check.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"unwrapTripCountCheck",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
/// }
/// ```
///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-/// moved of the given region, e.g. if it is side-effect free.
+/// moved out of the given region, e.g. if it is side-effect free or has only
+/// read side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+/// this check.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<void(Operation *, Region *)> moveOutOfRegion);
+ function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+ function_ref<void(Operation *, Region *)> moveOutOfRegion,
+ function_ref<LogicalResult()> unwrapGuard);
/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..148617c84547c7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+ auto lowerBound = this->getLowerBound();
+ auto upperBound = this->getUpperBound();
+ auto step = this->getStep();
+ auto initArgs = this->getInitArgs();
+ auto results = this->getResults();
+ auto loc = this->getLoc();
+
+ IRRewriter rewriter(this->getContext());
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ rewriter.setInsertionPointAfter(this->getOperation());
+
+ // Form the trip count calculation
+ auto subOp = rewriter.create<arith::SubIOp>(loc, upperBound, lowerBound);
+ auto ceilDivSIOp = rewriter.create<arith::CeilDivSIOp>(loc, subOp, step);
+ Value zero;
+ if (upperBound.getType().isIndex()) {
+ zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ } else {
+ zero = rewriter.create<arith::ConstantIntOp>(
+ loc, 0,
+ /*width=*/
+ upperBound.getType().getIntOrFloatBitWidth());
+ }
+ auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+ ceilDivSIOp, zero);
+ scf::YieldOp yieldInThen;
+ // Create the trip-count check
+ auto ifOp = rewriter.create<scf::IfOp>(
+ loc, cmpIOp,
+ [&](OpBuilder &builder, Location loc) {
+ yieldInThen = builder.create<scf::YieldOp>(loc, results);
+ },
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<scf::YieldOp>(loc, initArgs);
+ });
+
+ for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
+ rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+ // Move the scf.for into the then block
+ rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+ return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+ auto ifOp = (*this)->getParentRegion()->getParentOp();
+ if (!isa<scf::IfOp>(ifOp))
+ return failure();
+
+ auto wrappedForOp = this->getOperation();
+
+ IRRewriter rewriter(ifOp->getContext());
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ rewriter.setInsertionPoint(ifOp);
+
+ auto cmpOp = ifOp->getOperand(0).getDefiningOp();
+ auto ceilDivSIOp = cmpOp->getOperand(0).getDefiningOp();
+ auto zero = cmpOp->getOperand(1).getDefiningOp();
+ auto subOp = ceilDivSIOp->getOperand(0).getDefiningOp();
+ if (!isa<arith::CmpIOp>(cmpOp) || !isa<arith::CeilDivSIOp>(ceilDivSIOp) ||
+ !isa<arith::SubIOp>(subOp))
+ return failure();
+
+ rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+ for (auto [forOpResult, ifOpResult] :
+ llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+ rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+ rewriter.eraseOp(ifOp);
+ rewriter.eraseOp(cmpOp);
+ rewriter.eraseOp(zero);
+ rewriter.eraseOp(ceilDivSIOp);
+ rewriter.eraseOp(subOp);
+ return success();
+}
+
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3474,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
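The guard that `wrapInTripCountCheck` builds boils down to the predicate `ceildiv(ub - lb, step) > 0`. A minimal Python model of that arithmetic (using signed ceiling division to mirror `arith.ceildivsi`; a sketch of the condition only, not the IR construction):

```python
def ceildivsi(a: int, b: int) -> int:
    # Signed ceiling division, matching arith.ceildivsi semantics.
    # Python's // floors, so negate twice to round toward +infinity.
    return -(-a // b)

def loop_would_execute(lb: int, ub: int, step: int) -> bool:
    # The scf.if condition: trip count ceildiv(ub - lb, step) is positive.
    return ceildivsi(ub - lb, step) > 0
```

For example, `loop_would_execute(0, 10, 3)` is true (four iterations), while `loop_would_execute(5, 5, 1)` and an empty descending range are false, so the guarded reads are skipped.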
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
return wouldOpBeTriviallyDeadImpl(op);
}
+bool mlir::hasOnlyReadEffect(Operation *op) {
+ if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+ return memEffects.onlyHasEffect<MemoryEffects::Read>();
+ }
+ return false;
+}
+
bool mlir::isMemoryEffectFree(Operation *op) {
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
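`hasOnlyReadEffect` asks the `MemoryEffectOpInterface` whether every reported effect is a read; ops that do not implement the interface are conservatively rejected. A toy Python model of that classification (effect kinds as plain strings and `None` for "interface not implemented" are illustrative assumptions, not MLIR API):

```python
def has_only_read_effect(effects) -> bool:
    """Model of hasOnlyReadEffect: True iff the op exposes its effects
    (effects is not None) and every reported effect is a read.
    Ops reporting no effects at all take the isMemoryEffectFree path
    in the actual pass, so they are uninteresting here."""
    if effects is None:  # op does not implement MemoryEffectOpInterface
        return False
    return all(e == "read" for e in effects)
```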
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..1bdc74dc2a170a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
}
+static bool dependsOnGuarded(Operation *op,
+ function_ref<bool(OpOperand &)> condition) {
+ auto walkFn = [&](Operation *child) {
+ for (OpOperand &operand : child->getOpOperands()) {
+ if (!condition(operand))
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ };
+ return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+ function_ref<bool(Value)> definedOutsideGuard) {
+ return dependsOnGuarded(op, [&](OpOperand &operand) {
+ return definedOutsideGuard(operand.get());
+ });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  for (auto &region : loop->getRegions()) {
+ for (auto &block : region.getBlocks()) {
+ for (Operation &op : block.getOperations()) {
+ if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
- function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+ function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+ function_ref<void(Operation *, Region *)> moveOutOfRegion,
+ function_ref<LogicalResult()> unwrapGuard) {
size_t numMoved = 0;
for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");
+ auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+ loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+ size_t numMovedWithoutGuard = 0;
+
+ FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+ Region *loopRegion = region;
+ auto isLoopWrapped = false;
+ if (succeeded(ifOpAndRegion)) {
+ loopRegion = ifOpAndRegion->second;
+ isLoopWrapped = true;
+ }
+
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
- for (Operation &op : region->getOps())
+ for (Operation &op : loopRegion->getOps())
worklist.push(&op);
auto definedOutside = [&](Value value) {
- return isDefinedOutsideRegion(value, region);
+ return isDefinedOutsideRegion(value, loopRegion);
+ };
+
+ auto definedOutsideGuard = [&](Value value) {
+ return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
};
while (!worklist.empty()) {
Operation *op = worklist.front();
worklist.pop();
// Skip ops that have already been moved. Check if the op can be hoisted.
- if (op->getParentRegion() != region)
+ if (op->getParentRegion() != loopRegion)
continue;
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
- if (!shouldMoveOutOfRegion(op, region) ||
+
+ if (!shouldMoveOutOfRegion(op, loopRegion) ||
!canBeHoisted(op, definedOutside))
continue;
+ // Can only hoist pure ops (side-effect free) when there is an op with
+ // write side effects in the loop
+ if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+ continue;
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
- moveOutOfRegion(op, region);
+
+ auto moveWithoutGuard = isMemoryEffectFree(op) &&
+ !dependsOnGuarded(op, definedOutsideGuard) &&
+ isLoopWrapped;
+ numMovedWithoutGuard += moveWithoutGuard;
+
+ moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+ : loopRegion);
++numMoved;
// Since the op has been moved, we need to check its users within the
// top-level of the loop body.
for (Operation *user : op->getUsers())
- if (user->getParentRegion() == region)
+ if (user->getParentRegion() == loopRegion)
worklist.push(user);
}
+
+ // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+ if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+ auto tripCountCheckUnwrapped = unwrapGuard();
+ if (failed(tripCountCheckUnwrapped))
+ llvm_unreachable("Should not fail unwrapping trip-count check");
+ }
}
return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
- [&](Value value, Region *) {
- return loopLike.isDefinedOutsideOfLoop(value);
+ [&](Value value, Region *region) {
+ return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
+ return isSpeculatable(op) &&
+ (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+ },
+ [&]() { return loopLike.wrapInTripCountCheck(); },
+ [&](Operation *op, Region *region) {
+ op->moveBefore(region->getParentOp());
},
- [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+ [&]() { return loopLike.unwrapTripCountCheck(); });
}
namespace {
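Putting the pieces together, the extended `moveLoopInvariantCode` has this shape: wrap the loop in a guard, hoist pure ops past the guard and read-only ops just under it, then drop the guard if nothing ended up needing it. A compressed Python sketch of that control flow (the op model and the collapse of `dependsOnGuarded` into a single `pure` flag are simplifications for illustration, not the MLIR API):

```python
def move_loop_invariant_code(loop_ops, wrap_in_guard, unwrap_guard):
    """Sketch of the guarded-LICM driver. Each op is a dict with keys
    'invariant' (all operands defined outside the loop), 'pure'
    (memory-effect free, no dependence on guarded values) and
    'read_only' (only read side effects). wrap_in_guard returns None
    to model the FailureOr failure case."""
    guard = wrap_in_guard()
    moved = moved_without_guard = 0
    for op in loop_ops:
        if not op["invariant"] or not (op["pure"] or op["read_only"]):
            continue
        if op["pure"] and guard is not None:
            moved_without_guard += 1  # safe to hoist past the guard
        moved += 1
    # The guard is useless if every hoisted op was placed outside it.
    if guard is not None and moved_without_guard == moved:
        unwrap_guard()
    return moved
```

This mirrors the `numMovedWithoutGuard == numMoved` check at the end of the real loop, which triggers `unwrapTripCountCheck` to restore the unguarded scf.for.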
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..6f5cc60c59252c 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,111 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.subi
+ // CHECK: arith.ceildivsi
+ // CHECK: arith.constant
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.subi
+ // CHECK: arith.ceildivsi
+ // CHECK: arith.constant
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: arith.addi
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %add = arith.addi %always_read, %cst_0 : i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %add : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: scf.for
+ // CHECK: scf.if
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith....
[truncated]
@llvm/pr-subscribers-mlir-scf
+ if (failed(tripCountCheckUnwrapped))
+ llvm_unreachable("Should not fail unwrapping trip-count check");
+ }
}
return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
- [&](Value value, Region *) {
- return loopLike.isDefinedOutsideOfLoop(value);
+ [&](Value value, Region *region) {
+ return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
+ return isSpeculatable(op) &&
+ (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+ },
+ [&]() { return loopLike.wrapInTripCountCheck(); },
+ [&](Operation *op, Region *region) {
+ op->moveBefore(region->getParentOp());
},
- [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+ [&]() { return loopLike.unwrapTripCountCheck(); });
}
namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..6f5cc60c59252c 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,111 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.subi
+ // CHECK: arith.ceildivsi
+ // CHECK: arith.constant
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: arith.subi
+ // CHECK: arith.ceildivsi
+ // CHECK: arith.constant
+ // CHECK: arith.cmpi
+ // CHECK-NEXT: test.always_speculatable_op
+ // CHECK-NEXT: scf.if
+ // CHECK-NEXT: test.speculatable_op_with_memread
+ // CHECK-NEXT: arith.addi
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK-NOT: test.speculatable_op_with_memread
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %add = arith.addi %always_read, %cst_0 : i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %add : i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith.addi %acc, %i_cast : i32
+ %test_sum = arith.addi %i_sum, %always_read : i32
+ %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+ scf.yield %test_sum : i32
+ }
+ return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+ // CHECK: test.always_speculatable_op
+ // CHECK-NEXT: scf.for
+ // CHECK-NOT: test.always_speculatable_op
+ // CHECK: test.speculatable_op_with_memread
+ // CHECK: scf.for
+ // CHECK: scf.if
+ // CHECK: test.speculatable_op_with_memwrite
+ %cst_0 = arith.constant 0 : i32
+ %cst_42 = arith.constant dense<42> : tensor<64xi32>
+ %ind_42 = arith.constant 42 : index
+ %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+ %always = "test.always_speculatable_op"() : () -> i32
+ %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+ %i_cast = arith.index_cast %i: index to i32
+ %i_sum = arith....
[truncated]
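For readers skimming the truncated diff, a minimal sketch of the transformation (reconstructed from the FileCheck expectations in the tests above; value names are illustrative and not part of the patch):

```mlir
// Before: %r is loop-invariant but has a read side effect, so the old
// LICM refused to hoist it.
%sum = scf.for %i = %lb to %ub step %step iter_args(%acc = %c0) -> i32 {
  %r = "test.speculatable_op_with_memread"(%t, %idx) : (tensor<64xi32>, index) -> i32
  %next = arith.addi %acc, %r : i32
  scf.yield %next : i32
}

// After: a trip-count guard wraps the loop, and %r is hoisted into the
// guard so the read only happens when the loop executes at least once.
%range = arith.subi %ub, %lb : index
%tc = arith.ceildivsi %range, %step : index
%zero = arith.constant 0 : index
%has_iters = arith.cmpi sgt, %tc, %zero : index
%sum = scf.if %has_iters -> (i32) {
  %r = "test.speculatable_op_with_memread"(%t, %idx) : (tensor<64xi32>, index) -> i32
  %inner = scf.for %i = %lb to %ub step %step iter_args(%acc = %c0) -> i32 {
    %next = arith.addi %acc, %r : i32
    scf.yield %next : i32
  }
  scf.yield %inner : i32
} else {
  scf.yield %c0 : i32
}
```

Pure speculatable ops (`test.always_speculatable_op` in the tests) are still hoisted above the guard itself, since they need no trip-count protection.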
This is needed by Triton to address pytorch/pytorch#134535, cc @ThomasRaoux
I haven't looked at everything in detail, but overall it seems a bit too ad hoc and breaks the existing interfaces/separation of concerns.
Haven't looked into the details yet, but I believe we should reconsider the whole approach. The interface design and the modifications to the pass feel like overkill.
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated):

    IRRewriter rewriter(ifOp->getContext());
    OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPoint(ifOp);
This requires that this function be called immediately after wrapInTripCountCheck. How can this be guaranteed?
By the check here.
It checks whether the region was wrapped and no hoisted op needed the guard; only in that case does it unwrap, and this function is called only then.
Thanks for the reply, I'll look into details later.
Functions implemented by interfaces should be callable from anywhere. If there is a requirement for functions of multiple interfaces to be called in a specific order, it is recommended not to use interfaces.
Force-pushed from 2b57fa2 to 0e596fd
    // Otherwise, if the op does not implement the memory effect interface and
    // it does not have recursive side effects, then it cannot be known that the
    // op is moveable.
    return false;
This implementation is not robust and can hardly handle any operations with regions, because their terminators are inherently memory effect free.
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated):

    @@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

    std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

    FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
Strip blank lines
    LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
    moveOutOfRegion(op, region);

    auto moveWithoutGuard = isMemoryEffectFree(op) &&
why auto
    auto tripCountCheckUnwrapped = unwrapGuard();
    if (failed(tripCountCheckUnwrapped))
      llvm_unreachable("Should not fail unwrapping trip-count check");
    }
Is the unwrap part necessary? Can we use canonicalize to achieve that?
…rapped by a guard
I changed the approach as we discussed in the Speculative LICM? thread. Thanks for the discussion! @ThomasRaoux, @cxy-1993, @htyu The following differs from the initial change:
so that CSE and canonicalizer can do their job to get the following instead:
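The code blocks this comment originally showed did not survive extraction. As a rough, reconstructed sketch (based on the patch's tests, not the author's missing snippet; value and op names are illustrative): when every hoisted op is pure, cleanup leaves no guard behind at all:

```mlir
// Pure loop-invariant op hoisted without a guard; after CSE and
// canonicalization the trip-count check is gone and only the plain
// loop remains.
%always = "test.always_speculatable_op"() : () -> i32
%sum = scf.for %i = %lb to %ub step %step iter_args(%acc = %c0) -> i32 {
  %next = arith.addi %acc, %always : i32
  scf.yield %next : i32
}
```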
Ping
@qedawkins IIRC, you wanted to implement (or even implemented) something similar in the past (except for guarding non-reads too)?
I added a simple check to restrict LICM to loops with >= 1 trip count (even for speculatable ops): https://github.com/iree-org/iree/blob/278e63ad7fb790629c329ed0e12f39940ef75916/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp#L485 The idea was that even side effect free ops could still have execution cost, so hoisting wasn't safe (without a guard). It wasn't a very sophisticated check and doesn't ever generate a guard. The direction here is interesting for sure!
I don't have any objection to this PR, but I find the interface a bit leaky and the transformation a bit opinionated for it to live here.
I'm not actively maintaining this part of the code, so it would be better if someone who does could weigh in.
Otherwise, like I said, I don't have any objection to this going in.
Pinging @cxy-1993