Skip to content

Commit

Permalink
AIRFuseChannels: Fixup false assumption that loop unpeeling shall onl…
Browse files Browse the repository at this point in the history
…y happen at the outer-most loop (#880)

* Fixup typo and vs or

* Fixup the false assumption that unpeeling only happens at the outermost loop in a perfect loop nest

* Remove unused vars

* Add mlir ir test showing channel fusion at the innermost dimension
  • Loading branch information
erwei-xilinx authored Jan 24, 2025
1 parent 82436ea commit 386675b
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 52 deletions.
92 changes: 40 additions & 52 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3750,6 +3750,27 @@ class AIRFuseChannels
return true;
}

// Find the single control loop in longer_loop_nest that has no counterpart in
// shorter_loop_nest. Both nests are walked in lock-step; the first block of
// longer_loop_nest that is not equivalent to the current block of
// shorter_loop_nest is treated as the extra loop and skipped, after which the
// remaining blocks must pair up one-to-one.
//
// Returns the scf.for op owning the skipped block, or a null scf::ForOp when:
//  - longer_loop_nest is not exactly one entry longer than shorter_loop_nest;
//  - the skipped block's parent op is not an scf.for;
//  - more than one block fails to pair up;
//  - no mismatch is seen within the first shorter_loop_nest.size() entries.
//
// NOTE(review): in the last case the extra block is the final entry of
// longer_loop_nest and is never inspected, so a null op is returned. This
// mirrors the pre-existing behaviour; confirm against getParentLoopNest's
// ordering whether that position can hold an unpeelable scf.for.
scf::ForOp findMismatchControlLoop(std::vector<Block *> longer_loop_nest,
                                   std::vector<Block *> shorter_loop_nest) {
  if (longer_loop_nest.size() != shorter_loop_nest.size() + 1)
    return scf::ForOp();
  scf::ForOp mismatchScfFor = scf::ForOp();
  // longer_loop_nest has exactly one surplus entry, so at most one skip is
  // legal. Skipping a second time would advance `index` past the end of
  // longer_loop_nest (out-of-bounds read).
  bool skippedExtraLoop = false;
  unsigned index = 0;
  for (unsigned i = 0; i < shorter_loop_nest.size(); i++) {
    auto aLoop = shorter_loop_nest[i];
    auto bLoop = longer_loop_nest[index++];
    if (areEquivalentControlLoops(aLoop, bLoop))
      continue;
    // A second mismatch means the nests differ by more than one loop.
    if (skippedExtraLoop)
      return scf::ForOp();
    skippedExtraLoop = true;
    // Null if the surplus block is not the body of an scf.for.
    mismatchScfFor = dyn_cast<scf::ForOp>(bLoop->getParentOp());
    // Skip the surplus block; the next one must match aLoop.
    bLoop = longer_loop_nest[index++];
    if (!areEquivalentControlLoops(aLoop, bLoop))
      return scf::ForOp();
  }
  return mismatchScfFor;
}

std::tuple<bool, std::string>
checkIfTemporalMergeableByScfForImpl(std::vector<Block *> a_loop_nest,
std::vector<Block *> b_loop_nest) {
Expand All @@ -3758,57 +3779,22 @@ class AIRFuseChannels
std::tuple<bool, std::string> mergeableToUB = {true, "UB"};
if (std::abs((int)a_loop_nest.size() - (int)b_loop_nest.size()) != 1)
return notMergeable;
unsigned max_loop_nest_count =
std::max(a_loop_nest.size(), b_loop_nest.size());
// Skip over the unequal scf.for loop, and check equality for the rest of
// the loops first.
int outermostScfFor = -1;
for (unsigned i = 0; i < max_loop_nest_count; i++) {
unsigned scfForCount = 0;
if ((i < a_loop_nest.size()) &&
isa<scf::ForOp>(a_loop_nest[i]->getParentOp()))
scfForCount++;
if ((i < b_loop_nest.size()) &&
isa<scf::ForOp>(b_loop_nest[i]->getParentOp()))
scfForCount++;
if (scfForCount == 1)
outermostScfFor = (int)i;
}
if (outermostScfFor < 0)
return notMergeable;
SmallVector<unsigned> controlLoopIndices;
for (int i = 0; i < (int)max_loop_nest_count; i++)
if (i != outermostScfFor)
controlLoopIndices.push_back(i);
// TODO: Assuming a_loop_nest is before b_loop_nest. Always true? TODO:
// formalize using async dep.
unsigned index = 0;
if (a_loop_nest.size() < b_loop_nest.size()) {
for (auto i : controlLoopIndices) {
if (!areEquivalentControlLoops(a_loop_nest[index++], b_loop_nest[i]))
return notMergeable;
}
// Check if the skipped scf.for loop has LB >= 1. This is a sign of
// peeling, indicating opportunity of merge by unpeeling into LB.
auto outerMostScfFor =
dyn_cast<scf::ForOp>(b_loop_nest[outermostScfFor]->getParentOp());
assert(outerMostScfFor);
if (auto constLB = getConstantIntValue(outerMostScfFor.getLowerBound()))
auto mismatchScfFor = findMismatchControlLoop(b_loop_nest, a_loop_nest);
if (!mismatchScfFor)
return notMergeable;
if (auto constLB = getConstantIntValue(mismatchScfFor.getLowerBound()))
if (*constLB < 1)
return notMergeable;
// Merge by unpeeling into LB.
return mergeableToLB;
} else {
for (auto i : controlLoopIndices) {
if (!areEquivalentControlLoops(a_loop_nest[i], b_loop_nest[index++]))
return notMergeable;
}
auto mismatchScfFor = findMismatchControlLoop(a_loop_nest, b_loop_nest);
if (!mismatchScfFor)
return notMergeable;
// Merge by unpeeling into UB.
[[maybe_unused]] auto outerMostScfFor =
dyn_cast<scf::ForOp>(a_loop_nest[outermostScfFor]->getParentOp());
assert(outerMostScfFor);
return mergeableToUB;
}
return mergeableToLB;
}
std::tuple<bool, std::string>
checkIfTemporalMergeableImpl(std::vector<Block *> a_loop_nest,
Expand Down Expand Up @@ -4112,7 +4098,7 @@ class AIRFuseChannels
b->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
if (aHierSym && bHierSym && aHierSym.str() == bHierSym.str())
return true;
} else if (isa<affine::AffineIfOp>(a) || isa<affine::AffineIfOp>(b)) {
} else if (isa<affine::AffineIfOp>(a) && isa<affine::AffineIfOp>(b)) {
if (a == b)
return false; // Sharing the same affine.if means spatially parallel
// ops. Cannot merge by for loop (i.e. in time).
Expand Down Expand Up @@ -4177,22 +4163,24 @@ class AIRFuseChannels
void mergeChannelOpsTemporally(air::ChannelInterface a,
air::ChannelInterface b,
std::string mergeByLBOrUB) {
scf::ForOp parentForOp = a->getParentOfType<scf::ForOp>();
if (!parentForOp)
scf::ForOp mismatchScfFor =
findMismatchControlLoop(getParentLoopNest(a), getParentLoopNest(b));
if (!mismatchScfFor)
mismatchScfFor =
findMismatchControlLoop(getParentLoopNest(b), getParentLoopNest(a));
if (!mismatchScfFor)
return;
while (parentForOp && parentForOp->getParentOfType<scf::ForOp>()) {
parentForOp = parentForOp->getParentOfType<scf::ForOp>();
}
OpBuilder builder(parentForOp);

OpBuilder builder(mismatchScfFor);
if (mergeByLBOrUB == "LB") {
int originalLB = *getConstantIntValue(parentForOp.getLowerBound());
int originalLB = *getConstantIntValue(mismatchScfFor.getLowerBound());
assert(originalLB > 0);
int currLB = originalLB;
if (a->hasAttr("setLB"))
currLB = a->getAttrOfType<IntegerAttr>("setLB").getInt();
a->setAttr("setLB", builder.getI32IntegerAttr(currLB - 1));
} else if (mergeByLBOrUB == "UB") {
int originalUB = *getConstantIntValue(parentForOp.getUpperBound());
int originalUB = *getConstantIntValue(mismatchScfFor.getUpperBound());
int currUB = originalUB;
if (a->hasAttr("setUB"))
currUB = a->getAttrOfType<IntegerAttr>("setUB").getInt();
Expand Down
161 changes: 161 additions & 0 deletions mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1569,3 +1569,164 @@ module {
return
}
}

// -----

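// Fusing channels by unpeeling at the inner-most scf.for of the control loop
// nest: the put loops on @channel_8 / @channel_9, written below over [1, 15),
// are expected to appear over [0, 16) after the pass.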

// CHECK-LABEL: @func12
// CHECK: %[[FOR0:.*]] = scf.for
// CHECK: %[[FOR1:.*]] = scf.for
// CHECK: %[[FOR2:.*]] = scf.for %{{.*}} = %c0{{.*}} to %c16{{.*}} step %c1{{.*}}
// CHECK: %[[PUT0:.*]] = air.channel.put async{{.*}}@channel_8
// CHECK: scf.yield %[[PUT0]] : !air.async.token
// CHECK: }
// CHECK: scf.yield %[[FOR2]] : !air.async.token
// CHECK: }
// CHECK: scf.yield %[[FOR1]] : !air.async.token
// CHECK: }
// CHECK: %[[FOR0:.*]] = scf.for
// CHECK: %[[FOR1:.*]] = scf.for
// CHECK: %[[FOR2:.*]] = scf.for %{{.*}} = %c0{{.*}} to %c16{{.*}} step %c1{{.*}}
// CHECK: %[[PUT0:.*]] = air.channel.put async{{.*}}@channel_9
// CHECK: scf.yield %[[PUT0]] : !air.async.token
// CHECK: }
// CHECK: scf.yield %[[FOR2]] : !air.async.token
// CHECK: }
// CHECK: scf.yield %[[FOR1]] : !air.async.token
// CHECK: }
// CHECK: air.herd @herd_0
// CHECK: %[[FOR0:.*]] = scf.for
// CHECK: %[[FOR1:.*]] = scf.for
// CHECK: affine.if
// CHECK-NEXT: air.channel.get async{{.*}}@channel_8
// CHECK-NEXT: affine.yield
// CHECK-NEXT: else
// CHECK-NEXT: air.channel.get async{{.*}}@channel_9
// CHECK-NEXT: affine.yield
// CHECK-NEXT: }
// CHECK: scf.for %{{.*}} = %c1{{.*}} to %c15{{.*}} step %c1{{.*}}
// CHECK: affine.if
// CHECK-NEXT: air.channel.get async{{.*}}@channel_8
// CHECK-NEXT: affine.yield
// CHECK-NEXT: else
// CHECK-NEXT: air.channel.get async{{.*}}@channel_9
// CHECK-NEXT: affine.yield
// CHECK-NEXT: }
// CHECK: }
// CHECK: affine.if
// CHECK-NEXT: air.channel.get async{{.*}}@channel_8
// CHECK-NEXT: affine.yield
// CHECK-NEXT: else
// CHECK-NEXT: air.channel.get async{{.*}}@channel_9
// CHECK-NEXT: affine.yield
// CHECK-NEXT: }

// Integer set used by every affine.if below: s0 == 0 and 0 <= s1 <= 3.
#set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 3 >= 0)>
module {
// Six channels, all declared with shape [1, 1] and broadcast_shape [1, 4].
air.channel @channel_0 [1, 1] {broadcast_shape = [1, 4]}
air.channel @channel_1 [1, 1] {broadcast_shape = [1, 4]}
air.channel @channel_8 [1, 1] {broadcast_shape = [1, 4]}
air.channel @channel_9 [1, 1] {broadcast_shape = [1, 4]}
air.channel @channel_16 [1, 1] {broadcast_shape = [1, 4]}
air.channel @channel_17 [1, 1] {broadcast_shape = [1, 4]}
func.func @func12(%arg0: memref<1x1x4x8x4x8xbf16, 2 : i32>, %arg1: memref<8x16x32x32xbf16, 1 : i32>, %arg2: memref<8x16x32x32xbf16, 1 : i32>) {
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
%c15 = arith.constant 15 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c16384 = arith.constant 16384 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%0 = air.wait_all async
// Put on @channel_0 inside a two-deep scf.for nest; the second offset is the
// constant %c0.
%1 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = air.channel.put async [%arg6] @channel_0[] (%arg2[%arg3, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 6 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
// Same as above, on @channel_1.
%2 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = air.channel.put async [%arg6] @channel_1[] (%arg2[%arg3, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 7 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
// Put on @channel_8 inside the same two-deep nest plus one extra inner-most
// scf.for from %c1 to %c15; the second offset is that loop's induction
// variable %arg7.
%3 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = scf.for %arg7 = %c1 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (!air.async.token) {
%10 = air.channel.put async [%arg8] @channel_8[] (%arg2[%arg3, %arg7, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 22 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %10 : !air.async.token
}
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
// Same as above, on @channel_9.
%4 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = scf.for %arg7 = %c1 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (!air.async.token) {
%10 = air.channel.put async [%arg8] @channel_9[] (%arg2[%arg3, %arg7, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 23 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %10 : !air.async.token
}
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
// Herd holding the matching gets. Inside a two-deep scf.for nest it issues, in
// order: gets on @channel_0 / @channel_1, a %c1_1-to-%c15_0 loop of gets on
// @channel_8 / @channel_9, then gets on @channel_16 / @channel_17. Each pair
// is selected by an affine.if on #set applied to the tile ids %arg3, %arg4.
%5 = air.herd @herd_0 async tile (%arg3, %arg4) in (%arg5=%c4, %arg6=%c4) args(%arg7=%arg0) : memref<1x1x4x8x4x8xbf16, 2 : i32> attributes {id = 5 : i32, link_with = "mm.o"} {
%c15_0 = arith.constant 15 : index
%c1_1 = arith.constant 1 : index
%c4_2 = arith.constant 4 : index
%c0_3 = arith.constant 0 : index
%c8_4 = arith.constant 8 : index
%8 = air.wait_all async
%9 = scf.for %arg8 = %c0_3 to %c8_4 step %c4_2 iter_args(%arg9 = %8) -> (!air.async.token) {
%10 = scf.for %arg10 = %c0_3 to %c8_4 step %c4_2 iter_args(%arg11 = %arg9) -> (!air.async.token) {
%11 = affine.if #set()[%arg3, %arg4] -> !air.async.token {
%14 = air.channel.get async @channel_0[%arg3, %arg4] (%arg7[] [] []) {id = 14 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %14 : !air.async.token
} else {
%14 = air.channel.get async @channel_1[%arg3, %arg4] (%arg7[] [] []) {id = 15 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %14 : !air.async.token
}
scf.for %arg12 = %c1_1 to %c15_0 step %c1_1 {
%14 = affine.if #set()[%arg3, %arg4] -> !air.async.token {
%15 = air.channel.get async @channel_8[%arg3, %arg4] (%arg7[] [] []) {id = 30 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %15 : !air.async.token
} else {
%15 = air.channel.get async @channel_9[%arg3, %arg4] (%arg7[] [] []) {id = 31 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %15 : !air.async.token
}
}
%12 = affine.if #set()[%arg3, %arg4] -> !air.async.token {
%14 = air.channel.get async @channel_16[%arg3, %arg4] (%arg7[] [] []) {id = 47 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %14 : !air.async.token
} else {
%14 = air.channel.get async @channel_17[%arg3, %arg4] (%arg7[] [] []) {id = 48 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>)
affine.yield %14 : !air.async.token
}
%13 = air.wait_all async
scf.yield %13 : !air.async.token
}
scf.yield %10 : !air.async.token
}
}
// Put on @channel_16 inside the two-deep nest; the second offset is the
// constant %c15.
%6 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = air.channel.put async [%arg6] @channel_16[] (%arg2[%arg3, %c15, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 38 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
// Same as above, on @channel_17.
%7 = scf.for %arg3 = %c0 to %c8 step %c4 iter_args(%arg4 = %0) -> (!air.async.token) {
%8 = scf.for %arg5 = %c0 to %c8 step %c4 iter_args(%arg6 = %arg4) -> (!air.async.token) {
%9 = air.channel.put async [%arg6] @channel_17[] (%arg2[%arg3, %c15, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c16384, %c1024, %c8, %c128, %c32, %c1]) {id = 39 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
scf.yield %9 : !air.async.token
}
scf.yield %8 : !air.async.token
}
return
}
}

0 comments on commit 386675b

Please sign in to comment.