
[mlir] Enable LICM for ops with only read side effects in scf.for #120302

Open · wants to merge 4 commits into main
Conversation

ardaunal (Contributor)

Enable ops with only read side effects in scf.for to be hoisted under an scf.if guard that checks the trip count

This patch takes a step towards a less conservative LICM in MLIR as discussed in the following discourse thread:

Speculative LICM?

This patch in particular does the following:

  1. Relaxes the original hoisting constraint, which only hoisted ops free of any side effects. Ops with only read side effects may now also be hoisted into an scf.if guard, provided every op in the loop and its nested regions is either side-effect free or has only read side effects. The scf.if guard wraps the original scf.for and checks for trip_count > 0.
  2. To support this, two new interface methods are added to LoopLikeInterface: wrapInTripCountCheck and unwrapTripCountCheck. The implementation first wraps the scf.for loop in an scf.if guard using wrapInTripCountCheck; if no op has been hoisted into this guard by the time the worklist is exhausted, the guard is removed again by calling unwrapTripCountCheck.
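Concretely, the transformation looks roughly like the following (a hand-written sketch based on the patch's test expectations, not copied from the patch; `%t`, `%s`, and the value names are illustrative):

```mlir
// Before: %always_read is loop-invariant but has a read side effect,
// so it cannot be hoisted unconditionally.
%c0 = arith.constant 0 : i32
%sum = scf.for %i = %lb to %ub step %s iter_args(%acc = %c0) -> (i32) {
  %always_read = "test.speculatable_op_with_memread"(%t) : (tensor<64xi32>) -> i32
  %next = arith.addi %acc, %always_read : i32
  scf.yield %next : i32
}

// After: the read is hoisted into the then-block of an scf.if that
// only executes when the trip count is positive.
%c0 = arith.constant 0 : i32
%diff = arith.subi %ub, %lb : index
%tc = arith.ceildivsi %diff, %s : index
%zero = arith.constant 0 : index
%cond = arith.cmpi sgt, %tc, %zero : index
%sum = scf.if %cond -> (i32) {
  %always_read = "test.speculatable_op_with_memread"(%t) : (tensor<64xi32>) -> i32
  %inner = scf.for %i = %lb to %ub step %s iter_args(%acc = %c0) -> (i32) {
    %next = arith.addi %acc, %always_read : i32
    scf.yield %next : i32
  }
  scf.yield %inner : i32
} else {
  scf.yield %c0 : i32
}
```

When the loop would run zero times, the else-branch yields the loop's init args, so the hoisted read never executes and the results are unchanged.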

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:scf labels Dec 17, 2024
@llvmbot
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Arda Unal (ardaunal)

Changes


Patch is 22.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120302.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+2-1)
  • (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+20)
  • (modified) mlir/include/mlir/Interfaces/SideEffectInterfaces.h (+4)
  • (modified) mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h (+9-3)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+79-3)
  • (modified) mlir/lib/Interfaces/SideEffectInterfaces.cpp (+7)
  • (modified) mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (+85-11)
  • (modified) mlir/test/Transforms/loop-invariant-code-motion.mlir (+105)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+44)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 23c597a1ca5108..b54df8e3ef313d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
         "getLoopUpperBounds", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
+        "wrapInTripCountCheck", "unwrapTripCountCheck",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars", 
+          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
            "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
            "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
        RecursiveMemoryEffects,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index c6bffe347419e5..831830130b0ddc 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/"op->moveBefore($_op);"
     >,
+    InterfaceMethod<[{
+        Wraps the loop into a trip-count check.
+      }],
+      /*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
+      /*methodName=*/"wrapInTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return ::mlir::failure();"
+    >,
+    InterfaceMethod<[{
+        Unwraps the trip-count check.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"unwrapTripCountCheck",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
     InterfaceMethod<[{
         Promotes the loop body to its containing block if the loop is known to
         have a single iteration. Returns "success" if the promotion was
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec622fe4f8..1a7f66e2234949 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns true if the given operation implements `MemoryEffectOpInterface` and
+/// has only read effects.
+bool hasOnlyReadEffect(Operation *op);
+
 /// Returns the side effects of an operation. If the operation has
 /// RecursiveMemoryEffects, include all side effects of child operations.
 ///
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 3ceef44d799e89..ae6719abe79c00 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -48,15 +48,19 @@ class Value;
 /// }
 /// ```
 ///
-/// Users must supply three callbacks.
+/// Users must supply five callbacks.
 ///
 /// - `isDefinedOutsideRegion` returns true if the given value is invariant with
 ///   respect to the given region. A common implementation might be:
 ///   `value.getParentRegion()->isProperAncestor(region)`.
 /// - `shouldMoveOutOfRegion` returns true if the provided operation can be
-///   moved of the given region, e.g. if it is side-effect free.
+///   moved of the given region, e.g. if it is side-effect free or has only read
+///   side effects.
+/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
 /// - `moveOutOfRegion` moves the operation out of the given region. A common
 ///   implementation might be: `op->moveBefore(region->getParentOp())`.
+/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
+///   this check.
 ///
 /// An operation is moved if all of its operands satisfy
 /// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion);
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard);
 
 /// Move side-effect free loop invariant code out of a loop-like op using
 /// methods provided by the interface.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..148617c84547c7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+  auto lowerBound = this->getLowerBound();
+  auto upperBound = this->getUpperBound();
+  auto step = this->getStep();
+  auto initArgs = this->getInitArgs();
+  auto results = this->getResults();
+  auto loc = this->getLoc();
+
+  IRRewriter rewriter(this->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointAfter(this->getOperation());
+
+  // Form the trip count calculation
+  auto subOp = rewriter.create<arith::SubIOp>(loc, upperBound, lowerBound);
+  auto ceilDivSIOp = rewriter.create<arith::CeilDivSIOp>(loc, subOp, step);
+  Value zero;
+  if (upperBound.getType().isIndex()) {
+    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  } else {
+    zero = rewriter.create<arith::ConstantIntOp>(
+        loc, 0,
+        /*width=*/
+        upperBound.getType().getIntOrFloatBitWidth());
+  }
+  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+                                               ceilDivSIOp, zero);
+  scf::YieldOp yieldInThen;
+  // Create the trip-count check
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        yieldInThen = builder.create<scf::YieldOp>(loc, results);
+      },
+      [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, initArgs);
+      });
+
+  for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block
+  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+  auto ifOp = (*this)->getParentRegion()->getParentOp();
+  if (!isa<scf::IfOp>(ifOp))
+    return failure();
+
+  auto wrappedForOp = this->getOperation();
+
+  IRRewriter rewriter(ifOp->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPoint(ifOp);
+
+  auto cmpOp = ifOp->getOperand(0).getDefiningOp();
+  auto ceilDivSIOp = cmpOp->getOperand(0).getDefiningOp();
+  auto zero = cmpOp->getOperand(1).getDefiningOp();
+  auto subOp = ceilDivSIOp->getOperand(0).getDefiningOp();
+  if (!isa<arith::CmpIOp>(cmpOp) || !isa<arith::CeilDivSIOp>(ceilDivSIOp) ||
+      !isa<arith::SubIOp>(subOp))
+    return failure();
+
+  rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+  rewriter.eraseOp(ifOp);
+  rewriter.eraseOp(cmpOp);
+  rewriter.eraseOp(zero);
+  rewriter.eraseOp(ceilDivSIOp);
+  rewriter.eraseOp(subOp);
+  return success();
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3474,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
+bool mlir::hasOnlyReadEffect(Operation *op) {
+  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+    return memEffects.onlyHasEffect<MemoryEffects::Read>();
+  }
+  return false;
+}
+
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..1bdc74dc2a170a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(OpOperand &)> condition) {
+  auto walkFn = [&](Operation *child) {
+    for (OpOperand &operand : child->getOpOperands()) {
+      if (!condition(operand))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+  return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(Value)> definedOutsideGuard) {
+  return dependsOnGuarded(op, [&](OpOperand &operand) {
+    return definedOutsideGuard(operand.get());
+  });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  for (auto &region : loop->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      for (Operation &op : block.getOperations()) {
+        if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+    size_t numMovedWithoutGuard = 0;
+
+    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+    Region *loopRegion = region;
+    auto isLoopWrapped = false;
+    if (succeeded(ifOpAndRegion)) {
+      loopRegion = ifOpAndRegion->second;
+      isLoopWrapped = true;
+    }
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : region->getOps())
+    for (Operation &op : loopRegion->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, region);
+      return isDefinedOutsideRegion(value, loopRegion);
+    };
+
+    auto definedOutsideGuard = [&](Value value) {
+      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != region)
+      if (op->getParentRegion() != loopRegion)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
-      if (!shouldMoveOutOfRegion(op, region) ||
+
+      if (!shouldMoveOutOfRegion(op, loopRegion) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write side effects in the loop
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+
+      auto moveWithoutGuard = isMemoryEffectFree(op) &&
+                              !dependsOnGuarded(op, definedOutsideGuard) &&
+                              isLoopWrapped;
+      numMovedWithoutGuard += moveWithoutGuard;
+
+      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+                                           : loopRegion);
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == region)
+        if (user->getParentRegion() == loopRegion)
           worklist.push(user);
     }
+
+    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+      auto tripCountCheckUnwrapped = unwrapGuard();
+      if (failed(tripCountCheckUnwrapped))
+        llvm_unreachable("Should not fail unwrapping trip-count check");
+    }
   }
 
   return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) &&
+               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+      },
+      [&]() { return loopLike.wrapInTripCountCheck(); },
+      [&](Operation *op, Region *region) {
+        op->moveBefore(region->getParentOp());
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&]() { return loopLike.unwrapTripCountCheck(); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..6f5cc60c59252c 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,111 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add = arith.addi %always_read, %cst_0 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %add : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith....
[truncated]

@llvmbot
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-mlir-scf

+
+  rewriter.moveOpBefore(wrappedForOp, ifOp);
+
+  for (auto [forOpResult, ifOpResult] :
+       llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
+    rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
+
+  rewriter.eraseOp(ifOp);
+  rewriter.eraseOp(cmpOp);
+  rewriter.eraseOp(zero);
+  rewriter.eraseOp(ceilDivSIOp);
+  rewriter.eraseOp(subOp);
+  return success();
+}
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3474,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index c9feb001a19844..f45d5f3d227407 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
   return wouldOpBeTriviallyDeadImpl(op);
 }
 
+bool mlir::hasOnlyReadEffect(Operation *op) {
+  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
+    return memEffects.onlyHasEffect<MemoryEffects::Read>();
+  }
+  return false;
+}
+
 bool mlir::isMemoryEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
     if (!memInterface.hasNoEffect())
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 7460746934a78c..1bdc74dc2a170a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
       op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
 }
 
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(OpOperand &)> condition) {
+  auto walkFn = [&](Operation *child) {
+    for (OpOperand &operand : child->getOpOperands()) {
+      if (!condition(operand))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+  return op->walk(walkFn).wasInterrupted();
+}
+
+static bool dependsOnGuarded(Operation *op,
+                             function_ref<bool(Value)> definedOutsideGuard) {
+  return dependsOnGuarded(op, [&](OpOperand &operand) {
+    return definedOutsideGuard(operand.get());
+  });
+}
+
+static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
+  for (auto &region : loop->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      for (Operation &op : block.getOperations()) {
+        if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
-    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
+    function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
+    function_ref<void(Operation *, Region *)> moveOutOfRegion,
+    function_ref<LogicalResult()> unwrapGuard) {
   size_t numMoved = 0;
 
   for (Region *region : regions) {
     LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                             << *region->getParentOp() << "\n");
 
+    auto loopSideEffectFreeOrHasOnlyReadSideEffect =
+        loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
+
+    size_t numMovedWithoutGuard = 0;
+
+    FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
+    Region *loopRegion = region;
+    auto isLoopWrapped = false;
+    if (succeeded(ifOpAndRegion)) {
+      loopRegion = ifOpAndRegion->second;
+      isLoopWrapped = true;
+    }
+
     std::queue<Operation *> worklist;
     // Add top-level operations in the loop body to the worklist.
-    for (Operation &op : region->getOps())
+    for (Operation &op : loopRegion->getOps())
       worklist.push(&op);
 
     auto definedOutside = [&](Value value) {
-      return isDefinedOutsideRegion(value, region);
+      return isDefinedOutsideRegion(value, loopRegion);
+    };
+
+    auto definedOutsideGuard = [&](Value value) {
+      return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
     };
 
     while (!worklist.empty()) {
       Operation *op = worklist.front();
       worklist.pop();
       // Skip ops that have already been moved. Check if the op can be hoisted.
-      if (op->getParentRegion() != region)
+      if (op->getParentRegion() != loopRegion)
         continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
-      if (!shouldMoveOutOfRegion(op, region) ||
+
+      if (!shouldMoveOutOfRegion(op, loopRegion) ||
           !canBeHoisted(op, definedOutside))
         continue;
+      // Can only hoist pure ops (side-effect free) when there is an op with
+      // write side effects in the loop
+      if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
-      moveOutOfRegion(op, region);
+
+      auto moveWithoutGuard = isMemoryEffectFree(op) &&
+                              !dependsOnGuarded(op, definedOutsideGuard) &&
+                              isLoopWrapped;
+      numMovedWithoutGuard += moveWithoutGuard;
+
+      moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
+                                           : loopRegion);
       ++numMoved;
 
       // Since the op has been moved, we need to check its users within the
       // top-level of the loop body.
       for (Operation *user : op->getUsers())
-        if (user->getParentRegion() == region)
+        if (user->getParentRegion() == loopRegion)
           worklist.push(user);
     }
+
+    // Unwrap the loop if it was wrapped but no ops were moved in the guard.
+    if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
+      auto tripCountCheckUnwrapped = unwrapGuard();
+      if (failed(tripCountCheckUnwrapped))
+        llvm_unreachable("Should not fail unwrapping trip-count check");
+    }
   }
 
   return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
       loopLike.getLoopRegions(),
-      [&](Value value, Region *) {
-        return loopLike.isDefinedOutsideOfLoop(value);
+      [&](Value value, Region *region) {
+        return !region->isAncestor(value.getParentRegion());
       },
       [&](Operation *op, Region *) {
-        return isMemoryEffectFree(op) && isSpeculatable(op);
+        return isSpeculatable(op) &&
+               (isMemoryEffectFree(op) || hasOnlyReadEffect(op));
+      },
+      [&]() { return loopLike.wrapInTripCountCheck(); },
+      [&](Operation *op, Region *region) {
+        op->moveBefore(region->getParentOp());
       },
-      [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+      [&]() { return loopLike.unwrapTripCountCheck(); });
 }
 
 namespace {
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index e4c423ce7052bf..6f5cc60c59252c 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -593,6 +593,111 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: arith.subi
+  // CHECK: arith.ceildivsi
+  // CHECK: arith.constant
+  // CHECK: arith.cmpi
+  // CHECK-NEXT: test.always_speculatable_op
+  // CHECK-NEXT: scf.if
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK-NEXT: arith.addi
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add = arith.addi %always_read, %cst_0 : i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %add : i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith.addi %acc, %i_cast : i32
+    %test_sum = arith.addi %i_sum, %always_read : i32
+    %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %test_sum : i32
+  }
+  return %sum_result : i32
+}
+
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always = "test.always_speculatable_op"() : () -> i32
+    %always_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i: index to i32
+    %i_sum = arith....
[truncated]
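For orientation, the guard that wrapInTripCountCheck builds has the following shape, sketched from the diff above (value names are illustrative, not taken from the patch):

```mlir
// Before: an scf.for whose body is side-effect free or read-only.
%r = scf.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (i32) {
  // ... body ...
  scf.yield %next : i32
}

// After wrapInTripCountCheck: the loop moves into the then-branch of a
// trip-count check, so ops hoisted next to it never run for empty loops.
%diff = arith.subi %ub, %lb : index
%tc   = arith.ceildivsi %diff, %s : index
%c0   = arith.constant 0 : index
%cond = arith.cmpi sgt, %tc, %c0 : index
%r = scf.if %cond -> (i32) {
  %inner = scf.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (i32) {
    // ... body ...
    scf.yield %next : i32
  }
  scf.yield %inner : i32
} else {
  scf.yield %init : i32
}
```

Correspondingly, unwrapTripCountCheck pattern-matches this subi/ceildivsi/cmpi chain to undo the wrapping when no op ends up hoisted under the guard.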

@zero9178 zero9178 requested review from kuhar and qedawkins December 17, 2024 21:09
@htyu
Copy link
Contributor

htyu commented Dec 17, 2024

This is needed by Triton to address pytorch/pytorch#134535, cc @ThomasRaoux

Contributor

@ThomasRaoux ThomasRaoux left a comment

I haven't looked at everything in detail, but overall it seems a bit too ad hoc and breaks the existing interfaces/separation of concerns.

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (outdated, resolved)
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (outdated, resolved)
Contributor

@cxy-1993 cxy-1993 left a comment

Haven't looked into the details, but I believe we should reconsider the whole approach. The interface design and the modifications to the pass feel like overkill.

mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)

IRRewriter rewriter(ifOp->getContext());
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPoint(ifOp);
Contributor

This requires that this function be called immediately after wrapInTripCountCheck. How can this be guaranteed?

Contributor Author

@ardaunal ardaunal Dec 20, 2024

By the check here.

That check verifies the region was wrapped and that no hoisted op actually needed the guard; only in that case does it unwrap. This function is called only then.

Contributor

Thanks for the reply, I'll look into details later.

Contributor

Functions implemented by interfaces should be callable from anywhere. If there is a requirement for functions of multiple interfaces to be called in a specific order, it is recommended not to use interfaces.

mlir/lib/Dialect/SCF/IR/SCF.cpp (outdated, resolved)
mlir/lib/Interfaces/SideEffectInterfaces.cpp (outdated, resolved)
mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (outdated, resolved)
mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (outdated, resolved)
@ardaunal ardaunal force-pushed the speculative-licm-with-trip-count-check branch from 2b57fa2 to 0e596fd on December 19, 2024 22:55
// Otherwise, if the op does not implement the memory effect interface and
// it does not have recursive side effects, then it cannot be known that the
// op is moveable.
return false;
Contributor

This implementation is not robust and can hardly handle any operations with regions, because their terminators are inherently memory effect free.
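
One way to read this concern, with hypothetical IR reusing the patch's test dialect ops: an op with a region can hide a write behind an effect-free terminator, so an effect check has to recurse into nested regions rather than trust each op's own effect list:

```mlir
scf.for %i = %lb to %ub step %s {
  %v = scf.if %cond -> (i32) {
    // The write lives inside the nested region ...
    %w = "test.speculatable_op_with_memwrite"(%t) : (tensor<64xi32>) -> i32
    scf.yield %w : i32  // ... while the terminator itself reports no effects.
  } else {
    scf.yield %c0 : i32
  }
}
```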

@@ -395,6 +395,60 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {

Contributor

Strip blank lines




LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
moveOutOfRegion(op, region);

auto moveWithoutGuard = isMemoryEffectFree(op) &&
Contributor

why auto

auto tripCountCheckUnwrapped = unwrapGuard();
if (failed(tripCountCheckUnwrapped))
llvm_unreachable("Should not fail unwrapping trip-count check");
}
Contributor

Is the unwrap part necessary? Can we use canonicalize to achieve that?

@ardaunal
Contributor Author

ardaunal commented Jan 7, 2025

I changed the approach as we discussed on the Speculative LICM? thread.

Thanks for the discussion! @ThomasRaoux, @cxy-1993, @htyu

Following is different from the initial change:

  • The loop is no longer wrapped in a guard.
  • Ops with only read side effects are hoisted with a guard. The else branch of this guard yields ub.poison values with the same type(s) as the hoisted op's results.
  • Pure ops are hoisted without a guard unless an op was already hoisted with a guard; in that case, the pure op is hoisted with a guard too. This is needed to avoid interleaving branches such as:
  func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%arg0: index, %arg1: index, %arg2: index) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<42> : tensor<64xi32>
    %c42 = arith.constant 42 : index
    %0 = "test.always_speculatable_op"() : () -> i32
    %1 = arith.cmpi ult, %arg0, %arg1 : index
    %2 = scf.if %1 -> (i32) {
      %8 = "test.speculatable_op_with_memread"(%cst, %c42) : (tensor<64xi32>, index) -> i32
      scf.yield %8 : i32
    } else {
      %8 = ub.poison : i32
      scf.yield %8 : i32
    }
    %3 = arith.addi %0, %2 : i32
    %4 = arith.cmpi ult, %arg0, %arg1 : index
    %5 = scf.if %4 -> (i32) {
      %8 = "test.speculatable_op_with_memread"(%cst, %c42) : (tensor<64xi32>, index) -> i32
      scf.yield %8 : i32
    } else {
      %8 = ub.poison : i32
      scf.yield %8 : i32
    }
    %6 = arith.addi %3, %5 : i32
    %7 = scf.for %arg3 = %arg0 to %arg1 step %arg2 iter_args(%arg4 = %c0_i32) -> (i32) {
      %8 = arith.index_cast %arg3 : index to i32
      %9 = arith.addi %6, %8 : i32
      scf.yield %9 : i32
    }
    return %7 : i32
  }
}

so that CSE and canonicalizer can do their job to get the following instead:

  func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%arg0: index, %arg1: index, %arg2: index) -> i32 {
    %0 = ub.poison : i32
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<42> : tensor<64xi32>
    %c42 = arith.constant 42 : index
    %1 = "test.always_speculatable_op"() : () -> i32
    %2 = arith.cmpi ult, %arg0, %arg1 : index
    %3 = scf.if %2 -> (i32) {
      %5 = "test.speculatable_op_with_memread"(%cst, %c42) : (tensor<64xi32>, index) -> i32
      %6 = arith.addi %1, %5 : i32
      %7 = "test.speculatable_op_with_memread"(%cst, %c42) : (tensor<64xi32>, index) -> i32
      %8 = arith.addi %6, %7 : i32
      scf.yield %8 : i32
    } else {
      scf.yield %0 : i32
    }
    %4 = scf.for %arg3 = %arg0 to %arg1 step %arg2 iter_args(%arg4 = %c0_i32) -> (i32) {
      %5 = arith.index_cast %arg3 : index to i32
      %6 = arith.addi %3, %5 : i32
      scf.yield %6 : i32
    }
    return %4 : i32
  }
}
  • There is only one new interface function, moveOutOfLoopWithGuard, which is implemented by scf.for for now. The implementation for affine.for should be similar.

@ardaunal ardaunal requested a review from cxy-1993 January 7, 2025 00:26
@ardaunal ardaunal changed the title Enable LICM for ops with only read side effects in scf.for [mlir] Enable LICM for ops with only read side effects in scf.for Jan 13, 2025
@ardaunal
Contributor Author

Ping


mlir/lib/Interfaces/SideEffectInterfaces.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SCF/IR/SCF.cpp Outdated Show resolved Hide resolved
@kuhar
Member

kuhar commented Jan 22, 2025

@qedawkins IIRC, you wanted to implement (or even implemented) something similar in the past (except for guarding non-reads too)?

@qedawkins
Contributor

I added a simple check to restrict LICM to loops with >= 1 trip count (even for speculatable ops): https://github.com/iree-org/iree/blob/278e63ad7fb790629c329ed0e12f39940ef75916/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp#L485

The idea was that even side-effect-free ops can still have execution cost, so hoisting wasn't safe (without a guard). It wasn't a very sophisticated check and it never generates a guard. The direction here is interesting for sure!

@ardaunal ardaunal requested a review from kuhar January 28, 2025 23:43
Contributor

@ThomasRaoux ThomasRaoux left a comment

I don't have any objection to this PR, but I find the interface a bit leaky and the transformation a bit too opinionated to live here.
I'm not actively maintaining this part of the code, so it would be better if someone who does could weigh in.
Otherwise, like I said, I have no objection to this going in.

@ardaunal
Contributor Author

Pinging @cxy-1993
