Skip to content

Commit

Permalink
AIRDmaToChannel: Have HoistingAffineIf return LogicalResult (#883)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Jan 25, 2025
1 parent f742830 commit e7f682f
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 39 deletions.
11 changes: 11 additions & 0 deletions mlir/include/air/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ getAffineIfNestAndSpatialLoopFromOp(Operation *op,
std::vector<Operation *> &affine_if_nest,
Operation *&spatial_loop);

// Evaluate the condition lower and upper boundaries that the specified op is
// hitting, from an affine if nest. Assumes a rectangular condition bound
// region.
SmallVector<std::pair<int, int>> getRectangularConditionBoundsThroughAffineIfs(
Operation *op, Operation *spatial_loop,
std::vector<Operation *> affine_if_nest);

// Evaluate the integer value of affine set expression if the only symbolic
// identifier is replaced with zero
int evaluateSymbolEqualityInSet(AffineExpr c, MLIRContext *ctx);

struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
Expand Down
41 changes: 31 additions & 10 deletions mlir/lib/Transform/AIRDmaToChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,13 +531,14 @@ static void replaceAIRDmaWithAIRChannelPairs(
internalGetPutVector.push_back(internalGetPut);
}

static void HoistingAffineIf(affine::AffineIfOp op) {
LogicalResult HoistingAffineIf(affine::AffineIfOp op) {
auto ctx = op->getContext();

air::HierarchyInterface hier_op = nullptr;
unsigned int innerMemorySpace = 0;
auto herd = op->getParentOfType<air::HerdOp>();
assert(herd && "affine if op has no air.herdOp as parent");
if (!herd)
return failure();
auto segment = op->getParentOfType<air::SegmentOp>();
if (herd) {
hier_op = dyn_cast<air::HierarchyInterface>(herd.getOperation());
Expand Down Expand Up @@ -631,16 +632,33 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
module_builder.getUnknownLoc(), 0);

// Check for op buffer sizes
assert(internalGetPut.size() == externalGetPut.size());
assert(externalGetPut.size() == dmas.size());
if (internalGetPut.size() != externalGetPut.size())
return failure();
if (externalGetPut.size() != dmas.size())
return failure();

// Clone ops
unsigned dma_index = 0;
for (size_t i = 0; i < dmas.size(); i++) {
for (size_t dma_index = 0; dma_index < dmas.size(); dma_index++) {
// Get mapping for remapped ssa values entering the hoisted scf.parallel
IRMapping remap;
remap.map(herd.getIds()[0], zero_const_op);
remap.map(herd.getIds()[1], zero_const_op);
auto is = dmas[dma_index]
->getAttrOfType<IntegerSetAttr>("broadcast_set")
.getValue();
for (size_t herdDim = 0; herdDim < herd.getNumDims(); herdDim++) {
remap.map(herd.getIds()[herdDim], zero_const_op);
for (unsigned i = 0; i < is.getNumConstraints(); i++) {
if (!is.isEq(i))
continue;
auto c = is.getConstraint(i);
if (!c.isFunctionOfSymbol(herdDim))
continue;
auto constIV = module_builder.create<arith::ConstantIndexOp>(
module_builder.getUnknownLoc(),
air::evaluateSymbolEqualityInSet(c, ctx));
remap.map(herd.getIds()[herdDim], constIV);
}
}

int arg_idx = 0;
if (externalMemrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3 &&
segment) {
Expand Down Expand Up @@ -717,7 +735,6 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
for (auto o : clonedOps)
for (auto d : air::getAsyncDependenciesFromOp(hier_op))
air::addAsyncDependencyIfNew(o, d);
dma_index++;
}

module.walk([&](mlir::Operation *o) {
Expand All @@ -735,6 +752,8 @@ static void HoistingAffineIf(affine::AffineIfOp op) {
for (auto &dma : dmas) {
dma->erase();
}

return success();
}

namespace {
Expand Down Expand Up @@ -1402,7 +1421,9 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
f.walk([&](affine::AffineIfOp op) {
if (!op->getParentOfType<affine::AffineIfOp>()) {
// Only hoist top-level affine if op with a nest of if ops
HoistingAffineIf(op);
auto res = HoistingAffineIf(op);
if (failed(res))
signalPassFailure();
}
});
}
Expand Down
33 changes: 4 additions & 29 deletions mlir/lib/Util/Runner/RunnerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1666,35 +1666,10 @@ class runnerNode {
unsigned getSizeThroughAffineIf(Operation *op, Operation *spatial_loop,
std::vector<Operation *> affine_if_nest) {
unsigned output = 1;
SmallVector<int, 2> lbs_spatial;
SmallVector<int, 2> ubs_spatial;
getSizesFromSpatialLoop(spatial_loop, lbs_spatial, ubs_spatial);

// Walk through affine.if nest (in reverse order through vector)
for (auto it = affine_if_nest.rbegin(); it != affine_if_nest.rend(); ++it) {
auto affine_if = dyn_cast<affine::AffineIfOp>(*it);
// Get then integerset sizes
SmallVector<int, 2> lbs_int = {0, 0};
SmallVector<int, 2> ubs_int = {0, 0};
IntegerSet int_set = affine_if.getIntegerSet();
getSizesFromIntegerSet(affine_if->getContext(), int_set, lbs_int,
ubs_int);
// If found then block containing op
if (affine_if.getThenBlock()->findAncestorOpInBlock(*op)) {
for (unsigned i = 0; i < lbs_int.size(); i++) {
output *= ubs_int[i] - lbs_int[i] + 1;
}
return output;
}
// Else keep going, while updating the spatial sizes wrt else condition
else {
getElseSizesFromAffineIf(lbs_spatial, ubs_spatial, lbs_int, ubs_int);
}
}
// If op isn't in any then blocks in affine.if nest
for (unsigned i = 0; i < lbs_spatial.size(); i++) {
output *= ubs_spatial[i] - lbs_spatial[i] + 1;
}
auto conditionBounds = air::getRectangularConditionBoundsThroughAffineIfs(
op, spatial_loop, affine_if_nest);
for (auto [lbs_int, ubs_int] : conditionBounds)
output *= ubs_int - lbs_int + 1;
return output;
}

Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,62 @@ Operation *air::getAffineIfNestAndSpatialLoopFromOp(
return parent;
}

// Evaluate the condition lower and upper boundaries that the specified op is
// hitting, from an affine if nest. Assumes a rectangular condition bound
// region.
SmallVector<std::pair<int, int>>
air::getRectangularConditionBoundsThroughAffineIfs(
Operation *op, Operation *spatial_loop,
std::vector<Operation *> affine_if_nest) {
SmallVector<int, 2> lbs_spatial;
SmallVector<int, 2> ubs_spatial;
air::getSizesFromSpatialLoop(spatial_loop, lbs_spatial, ubs_spatial);

// Walk through affine.if nest (in reverse order through vector)
for (auto it = affine_if_nest.rbegin(); it != affine_if_nest.rend(); ++it) {
auto affine_if = dyn_cast<affine::AffineIfOp>(*it);
// Get then integerset sizes
SmallVector<int, 2> lbs_int = {0, 0};
SmallVector<int, 2> ubs_int = {0, 0};
IntegerSet int_set = affine_if.getIntegerSet();
air::getSizesFromIntegerSet(affine_if->getContext(), int_set, lbs_int,
ubs_int);
// If found then block containing op
SmallVector<std::pair<int, int>> conditionBounds; // <LB, UB>
if (affine_if.getThenBlock()->findAncestorOpInBlock(*op)) {
for (unsigned i = 0; i < lbs_int.size(); i++)
conditionBounds.push_back(std::make_pair(lbs_int[i], ubs_int[i]));
return conditionBounds;
}
// Else keep going, while updating the spatial sizes wrt else condition
else {
air::getElseSizesFromAffineIf(lbs_spatial, ubs_spatial, lbs_int, ubs_int);
}
}

// If op isn't in any then blocks in affine.if nest
SmallVector<std::pair<int, int>> conditionBounds; // <LB, UB>
for (unsigned i = 0; i < lbs_spatial.size(); i++)
conditionBounds.push_back(std::make_pair(lbs_spatial[i], ubs_spatial[i]));
return conditionBounds;
}

// Evaluate the integer value of affine set expression if the only symbolic
// identifier is replaced with zero
int air::evaluateSymbolEqualityInSet(AffineExpr c, MLIRContext *ctx) {
assert(c.isSymbolicOrConstant() && "constraint has dimension identifier");
SmallVector<AffineExpr, 2> zero_syms{
getAffineConstantExpr(0, ctx),
getAffineConstantExpr(0, ctx),
};
auto newC = c.replaceSymbols(zero_syms);
auto expr = dyn_cast<AffineConstantExpr>(simplifyAffineExpr(newC, 0, 1));
assert(expr);
int result = expr.getValue();
// Both + and - constant eval are legal for AffineExpr
return (result >= 0) ? (result) : (-result);
}

// Check if an operand of an operation is read or write access
char air::checkOpOperandReadOrWrite(Value v, Operation *owner) {
for (auto &op_operand : owner->getOpOperands()) {
Expand Down
81 changes: 81 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/dma_to_channel_async.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,84 @@ module {
return
}
}

// -----

// CHECK-LABEL: @func2
// CHECK: air.segment
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for

// CHECK: %[[TOK0:.*]], %[[EXEC0:.*]] = air.execute
// CHECK-NEXT: affine.apply {{.*}}[%{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put async{{.*}}@channel_0[] (%{{.*}}[%[[EXEC0]],
// CHECK: %[[TOK1:.*]], %[[EXEC1:.*]] = air.execute
// CHECK-NEXT: affine.apply {{.*}}[%{{.*}}, %c1{{.*}}]
// CHECK: air.channel.put async{{.*}}@channel_1[] (%{{.*}}[%[[EXEC1]],
// CHECK: air.herd

#map = affine_map<()[s0, s1] -> (s0 + s1)>
#set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>
#set1 = affine_set<()[s0, s1] : (s0 - 1 == 0, s1 >= 0, -s1 + 1 >= 0)>
module {
func.func @func2() {
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%0 = air.launch async (%arg3, %arg4) in (%arg5=%c2, %arg6=%c16) attributes {id = 5 : i32} {
%1 = air.segment @segment_0 async attributes {id = 4 : i32} {
%c15 = arith.constant 15 : index
%c4 = arith.constant 4 : index
%c2_0 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%async_token, %results = air.execute -> (memref<1x1x4x8x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x4x8x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x4x8x4x8xbf16, 2 : i32>
} {id = 7 : i32}
%async_token_1, %results_2 = air.execute -> (memref<8x16x32x32xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<8x16x32x32xbf16, 1 : i32>
air.execute_terminator %alloc : memref<8x16x32x32xbf16, 1 : i32>
} {id = 9 : i32}
scf.for %arg7 = %c0 to %c8 step %c4 {
scf.for %arg8 = %c0 to %c8 step %c4 {
%2 = scf.for %arg9 = %c1 to %c15 step %c1 iter_args(%arg10 = %async_token) -> (!air.async.token) {
%3 = air.herd @herd_0 async [%arg10] tile (%arg11, %arg12) in (%arg13=%c2_0, %arg14=%c2_0) args(%arg15=%arg7, %arg16=%results, %arg17=%results_2, %arg18=%arg9) : index, memref<1x1x4x8x4x8xbf16, 2 : i32>, memref<8x16x32x32xbf16, 1 : i32>, index attributes {id = 2 : i32, link_with = "mm.o"} {
%c0_5 = arith.constant 0 : index
%c1_6 = arith.constant 1 : index
%c4_7 = arith.constant 4 : index
%c8_8 = arith.constant 8 : index
%c16384 = arith.constant 16384 : index
%c1024 = arith.constant 1024 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%async_token_9, %results_10 = air.execute -> (index) {
%5 = affine.apply #map()[%arg15, %arg11]
air.execute_terminator %5 : index
} {id = 16 : i32}
%4 = affine.if #set()[%arg11, %arg12] -> !air.async.token {
%5 = air.dma_memcpy_nd async [%async_token_9] (%arg16[] [] [], %arg17[%results_10, %arg18, %c0_5, %c0_5, %c0_5, %c0_5] [%c1_6, %c1_6, %c4_7, %c8_8, %c4_7, %c8_8] [%c16384, %c1024, %c8_8, %c128, %c32, %c1_6]) {broadcast_set = #set, id = 11 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>, memref<8x16x32x32xbf16, 1 : i32>)
affine.yield %5 : !air.async.token
} else {
%5 = air.dma_memcpy_nd async [%async_token_9] (%arg16[] [] [], %arg17[%results_10, %arg18, %c0_5, %c0_5, %c0_5, %c0_5] [%c1_6, %c1_6, %c4_7, %c8_8, %c4_7, %c8_8] [%c16384, %c1024, %c8_8, %c128, %c32, %c1_6]) {broadcast_set = #set1, id = 12 : i32} : (memref<1x1x4x8x4x8xbf16, 2 : i32>, memref<8x16x32x32xbf16, 1 : i32>)
affine.yield %5 : !air.async.token
}
}
scf.yield %3 : !air.async.token
}
}
}
%async_token_3 = air.execute {
memref.dealloc %results_2 : memref<8x16x32x32xbf16, 1 : i32>
} {id = 22 : i32}
%async_token_4 = air.execute [%async_token] {
memref.dealloc %results : memref<1x1x4x8x4x8xbf16, 2 : i32>
} {id = 24 : i32}
}
}
return
}
}


0 comments on commit e7f682f

Please sign in to comment.