Skip to content

Commit

Permalink
AIRCanonicalize: Fixup dependency canonicalizer when air.wait_all w…
Browse files Browse the repository at this point in the history
…as a source (#882)

* Fixup dependency canonicalizer where wait_all should never be the source for memref RW analysis

* Add more mlir ir tests showing the memref RW-based dependency canonicalizer
  • Loading branch information
erwei-xilinx authored Jan 24, 2025
1 parent d411747 commit f742830
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
21 changes: 16 additions & 5 deletions mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,23 @@ static LogicalResult CanonicalizeAsyncOpDeps(OpT op,
auto memrefsReadBySinkOp = getAllMemrefsReadByOp(op.getOperation());
auto memrefsWrittenBySinkOp = getAllMemrefsWrittenByOp(op.getOperation());
// make a list of new async token operands
std::function<void(SmallVector<Value>, SmallVector<Value> &)>
getDirectDependenciesGreedily;
getDirectDependenciesGreedily = [&getDirectDependenciesGreedily](
SmallVector<Value> depList,
SmallVector<Value> &directDeps) {
for (auto v : depList) {
if (auto wa = dyn_cast_if_present<air::WaitAllOp>(v.getDefiningOp()))
getDirectDependenciesGreedily(wa.getAsyncDependencies(), directDeps);
else
directDeps.push_back(v);
}
return;
};
llvm::SetVector<Value> newAsyncDeps; // don't include duplicates
for (auto v : op.getAsyncDependencies()) {
// don't include wait_all ops with no operands
if (auto wa = dyn_cast_if_present<WaitAllOp>(v.getDefiningOp()))
if (wa.getAsyncDependencies().size() == 0)
continue;
SmallVector<Value> directDeps;
getDirectDependenciesGreedily(op.getAsyncDependencies(), directDeps);
for (auto v : directDeps) {
// don't include any false dependencies, i.e. sink does not depend on source
// in RAW, WAR or WAW; RAR is a false dependency
if (v.getDefiningOp()) {
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/AIR/air_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,48 @@ func.func @chan_0(%arg0 : memref<4x1x64x64xbf16>, %arg1 : memref<1x4x64x64xbf16>
return
}

// CHECK: func.func @chan_1
// CHECK: %[[TOKEN0:.*]] = air.channel.put async @channel_0
// CHECK: %[[TOKEN1:.*]] = air.channel.put async @channel_1
func.func @chan_1(%arg0 : memref<4x1x64x64xbf16>) {
%1 = air.channel.put async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
%2 = air.channel.put async [%1] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
return
}

// CHECK: func.func @chan_2
// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0
// CHECK: %[[TOKEN1:.*]] = air.channel.get async [%[[TOKEN0]]] @channel_1
func.func @chan_2(%arg0 : memref<4x1x64x64xbf16>) {
%0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
%1 = air.channel.get async [%0] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
return
}

// CHECK: func.func @chan_3
// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0
// CHECK: %[[TOKEN1:.*]] = air.channel.put async [%[[TOKEN0]]] @channel_1
func.func @chan_3(%arg0 : memref<4x1x64x64xbf16>) {
%0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
%1 = air.wait_all async [%0]
%2 = air.channel.put async [%1] @channel_1[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
return
}

// CHECK: func.func @chan_4
// CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0
// CHECK: %[[TOKEN1:.*]] = air.channel.get async @channel_1
// CHECK: %[[TOKEN2:.*]] = air.channel.put async [%[[TOKEN0]]] @channel_2
// CHECK: %[[TOKEN3:.*]] = air.channel.put async [%[[TOKEN1]]] @channel_3
func.func @chan_4(%arg0 : memref<4x1x64x64xbf16>, %arg1 : memref<4x1x64x64xbf16>) {
%0 = air.channel.get async @channel_0[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
%1 = air.channel.get async @channel_1[] (%arg1[] [] []) : (memref<4x1x64x64xbf16>)
%2 = air.wait_all async [%0, %1]
%3 = air.channel.put async [%2] @channel_2[] (%arg0[] [] []) : (memref<4x1x64x64xbf16>)
%4 = air.channel.put async [%2] @channel_3[] (%arg1[] [] []) : (memref<4x1x64x64xbf16>)
return
}

// CHECK: func.func @dma_compose_subview
// CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}]
func.func @dma_compose_subview() {
Expand Down

0 comments on commit f742830

Please sign in to comment.