Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Enable LICM for ops with only read side effects in scf.for #120302

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
and then lowered to some final target like LLVM or SPIR-V.
}];

let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
}

// Base class for SCF dialect ops.
Expand Down Expand Up @@ -138,7 +138,9 @@ def ForOp : SCF_Op<"for",
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"moveOutOfLoopWithGuard",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
Expand Down Expand Up @@ -302,7 +304,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
InterfaceMethod<[{
Moves the given loop-invariant operation out of the loop with a
trip-count guard.
}],
/*retTy=*/"void",
/*methodName=*/"moveOutOfLoopWithGuard",
/*args=*/(ins "::mlir::Operation *":$op),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return;
}]
>,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns true if the given operation is free of memory effects or has only
/// read effect.
bool isMemoryEffectFreeOrOnlyRead(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
Expand Down
16 changes: 10 additions & 6 deletions mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@ class Value;
/// }
/// }
/// ```
///
/// Users must supply three callbacks.
/// Users must supply four callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
/// moved of the given region, e.g. if it is side-effect free.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
/// moved of the given region, e.g. if it is side-effect free or has only read
/// side effects.
/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
/// without a guard. A common implementation might be:
/// `op->moveBefore(region->getParentOp())`.
/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
/// with a guard.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
Expand All @@ -66,7 +69,8 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion);
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
function_ref<void(Operation *)> moveOutOfRegionWithGuard);

/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
Expand Down
40 changes: 37 additions & 3 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -395,6 +396,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

/// Moves the op out of the loop with a guard that checks if the loop has at
/// least one iteration.
///
/// The op is wrapped in an `scf.if` conditioned on the loop's trip count
/// being non-zero; on the else path its results are replaced by `ub.poison`
/// values, which is sound because those values were never produced (nor
/// observable) when the loop body did not execute.
void ForOp::moveOutOfLoopWithGuard(Operation *op) {
  IRRewriter rewriter(this->getContext());
  OpBuilder::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPoint(this->getOperation());
  Location loc = this->getLoc();
  // The loop executes at least once iff lowerBound < upperBound.
  // NOTE(review): `ult` compares the bounds as unsigned; confirm this matches
  // scf.for's semantics for negative `index` bounds (vs. `slt`).
  arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, this->getLowerBound(),
      this->getUpperBound());
  // Create the trip-count check.
  scf::YieldOp thenYield;
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, cmpIOp,
      [&](OpBuilder &builder, Location loc) {
        thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
      },
      [&](OpBuilder &builder, Location loc) {
        // Use `builder` (whose insertion point is inside the else block), not
        // the outer `rewriter`, so the poison ops are materialized in the
        // else region instead of ahead of the scf.if.
        SmallVector<Value> poisonResults;
        poisonResults.reserve(op->getNumResults());
        for (Type type : op->getResultTypes())
          poisonResults.push_back(
              builder.create<ub::PoisonOp>(loc, type, nullptr));
        builder.create<scf::YieldOp>(loc, poisonResults);
      });
  // Reroute every use of the op's results to the guarded scf.if results,
  // except the then-yield that feeds the results into the scf.if itself.
  for (auto [opResult, ifOpResult] :
       llvm::zip_equal(op->getResults(), ifOp->getResults()))
    rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
  // Move the op into the then block, right before its yield.
  rewriter.moveOpBefore(op, thenYield);
}

/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down Expand Up @@ -3397,9 +3432,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {

if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
<< "expected as many input types as operands " << "(expected "
<< operands.size() << " got " << functionType.getNumInputs() << ")";
}

// Resolve input operands.
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;
Expand Down Expand Up @@ -363,6 +364,17 @@ mlir::getEffectsRecursively(Operation *rootOp) {
return effects;
}

/// Returns true if the operation's memory effects (including those of nested
/// ops, when recursively known) are either absent entirely or consist solely
/// of reads. Returns false conservatively when the effects cannot be
/// determined.
bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
  auto maybeEffects = getEffectsRecursively(op);
  if (!maybeEffects.has_value())
    return false; // Unknown effects: assume the worst.
  for (const MemoryEffects::EffectInstance &instance : *maybeEffects)
    if (!isa<MemoryEffects::Read>(instance.getEffect()))
      return false;
  return true;
}

bool mlir::isSpeculatable(Operation *op) {
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
if (!conditionallySpeculatable)
Expand Down
34 changes: 27 additions & 7 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,18 @@ size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
size_t numMoved = 0;

for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");

bool anyOpHoistedWithGuard = false;
bool loopSideEffectFreeOrHasOnlyReadSideEffect =
isMemoryEffectFreeOrOnlyRead(region->getParentOp());

std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
for (Operation &op : region->getOps())
Expand All @@ -84,12 +89,26 @@ size_t mlir::moveLoopInvariantCode(
continue;

LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");

if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
// write and/or unknown side effects in the loop.
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;

LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
moveOutOfRegion(op, region);
bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
if (moveWithoutGuard) {
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
<< " without guard\n");
moveOutOfRegionWithoutGuard(op);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Moving loop-invariant op: " << *op << " with guard\n");
moveOutOfRegionWithGuard(op);
anyOpHoistedWithGuard = true;
}
++numMoved;

// Since the op has been moved, we need to check its users within the
Expand All @@ -106,13 +125,14 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
return moveLoopInvariantCode(
loopLike.getLoopRegions(),
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
[&](Value value, Region *region) {
return !region->isAncestor(value.getParentRegion());
},
[&](Operation *op, Region *) {
return isMemoryEffectFree(op) && isSpeculatable(op);
return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
},
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
[&](Operation *op) { loopLike.moveOutOfLoop(op); },
[&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
}

namespace {
Expand Down
144 changes: 144 additions & 0 deletions mlir/test/Transforms/loop-invariant-code-motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,150 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}

// LICM hoists a speculatable, read-only op out of scf.for: the hoisted op is
// guarded by an scf.if on `lb ult ub` whose else branch yields ub.poison, and
// the always-speculatable (pure) op is hoisted without any guard.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: scf.if %[[CMP]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Same as the single-result case, but the read-only op has two results: the
// guarding scf.if's else branch must yield one ub.poison per result type
// (i32 and f32).
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: scf.if %[[CMP]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK-NEXT: ub.poison : f32
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read#0 : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// A chain of hoistable ops where later ops depend on earlier hoisted results:
// once one op is hoisted with a guard, each dependent op gets its own
// trip-count check and scf.if (four guards total here), while the pure op is
// still hoisted unguarded.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
// CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
// CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
// CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%add_0 = arith.addi %always_speculate, %only_read_0 : i32
%only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%add_1 = arith.addi %add_0, %only_read_1 : i32
%i_cast = arith.index_cast %i: index to i32
%sum = arith.addi %add_1, %i_cast : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Negative test: when the loop also contains a write, the read-only op must
// stay inside the loop (the write could affect what it reads). Only the pure
// always-speculatable op is hoisted.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: scf.for
// CHECK-NOT: test.always_speculatable_op
// CHECK: test.speculatable_op_with_memread
// CHECK: test.speculatable_op_with_memwrite
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
%write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Negative test: a write buried inside a nested scf.for/scf.if still counts as
// a write effect of the outer loop (effects are collected recursively), so the
// read-only op is not hoisted; only the pure op is.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: scf.for
// CHECK-NOT: test.always_speculatable_op
// CHECK: test.speculatable_op_with_memread
// CHECK: scf.for
// CHECK: scf.if
// CHECK: test.speculatable_op_with_memwrite
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
scf.for %j = %lb to %ub step %step {
%eq42 = arith.cmpi eq, %j, %ind_42 : index
scf.if %eq42 {
%always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
}
}
scf.yield %sum : i32
}
return %sum_result : i32
}

// -----

func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
Expand Down
Loading