Skip to content

Commit

Permalink
AIRSplitL2Memref: Splitting with temporally fused air.channel (#794)
Browse files Browse the repository at this point in the history
* Enable splitting channels with puts and gets over time (i.e. fused channels)

* Unit test for splitting channels with puts and gets over time
  • Loading branch information
erwei-xilinx authored Nov 25, 2024
1 parent eb17d1b commit 8f8a1d5
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 4 deletions.
28 changes: 24 additions & 4 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,18 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
if (targetMemrefs.empty())
return;

// Look up or create a new air.channel name.
std::map<std::string, std::string> chanNameMap;
auto lookUpOrCreateChanName = [&](std::string chanName, ModuleOp module) {
if (chanNameMap.count(chanName))
return chanNameMap[chanName];
else {
auto newChanName = air::createChannelName(module);
chanNameMap[chanName] = newChanName;
return newChanName;
}
};

// Tile memrefs.
llvm::SmallSet<Operation *, 1> erased;
for (auto allocOp : targetMemrefs) {
Expand Down Expand Up @@ -1528,10 +1540,18 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
SmallVector<Type, 4> tys = {
air::AsyncTokenType::get(chanUserOp->getContext())};
auto cname =
air::createChannelName(chanUserOp->getParentOfType<ModuleOp>());
SmallVector<int64_t, 2> channel_sizes = {targetColTilingFactor, 1};
auto new_chan = builder.create<air::ChannelOp>(
loc, cname, builder.getI64ArrayAttr(channel_sizes));
lookUpOrCreateChanName(chanUserOp.getChanName().str(),
chanUserOp->getParentOfType<ModuleOp>());
air::ChannelOp new_chan;
auto new_chan_op = mlir::SymbolTable::lookupSymbolIn(
chanUserOp->getParentOfType<ModuleOp>(), cname);
if (new_chan_op) {
new_chan = dyn_cast<air::ChannelOp>(new_chan_op);
} else {
SmallVector<int64_t, 2> channel_sizes = {targetColTilingFactor, 1};
new_chan = builder.create<air::ChannelOp>(
loc, cname, builder.getI64ArrayAttr(channel_sizes));
}
auto memrefShape = air::getTensorShape(memref.getType());

int dim = 0;
Expand Down
131 changes: 131 additions & 0 deletions mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1490,3 +1490,134 @@ module {
return
}
}


// -----

// Splitting channels which have puts and gets which span over multiple time phases.

// CHECK-LABEL: func.func @test11
// CHECK: air.channel.put {{.*}} @channel_0[%c0{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c1{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c2{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c3{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c0{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c1{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c2{{.*}}, %c0{{.*}}]
// CHECK: air.channel.put {{.*}} @channel_0[%c3{{.*}}, %c0{{.*}}]

// CHECK: air.segment

// CHECK: air.channel.get {{.*}} @channel_0[%c0{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c1{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c2{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c3{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c0{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c1{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c2{{.*}}, %c0{{.*}}]
// CHECK: air.channel.get {{.*}} @channel_0[%c3{{.*}}, %c0{{.*}}]

#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<()[s0] -> (s0 * 32)>
module {
air.channel @channel_4 [1, 1]
air.channel @channel_6 [4, 4]
func.func @test11(%arg0: memref<1024x512xi32>, %arg1: memref<512x1024xi32>, %arg2: memref<1024x1024xi32>, %arg3: memref<1024x1024xi32>) {
%c8 = arith.constant 8 : index
%0 = air.launch async (%arg4, %arg5) in (%arg6=%c8, %arg7=%c8) args(%arg8=%arg0, %arg9=%arg1) : memref<1024x512xi32>, memref<512x1024xi32> attributes {id = 1 : i32} {
%c16 = arith.constant 16 : index
%c32768 = arith.constant 32768 : index
%c1024 = arith.constant 1024 : index
%c512 = arith.constant 512 : index
%c32 = arith.constant 32 : index
%c16384 = arith.constant 16384 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%async_token, %results = air.execute -> (index) {
%4 = affine.apply #map()[%arg4]
air.execute_terminator %4 : index
}
%1 = air.wait_all async
%2 = scf.for %arg10 = %c0 to %c16 step %c1 iter_args(%arg11 = %1) -> (!air.async.token) {
%async_token_0, %results_1 = air.execute [%arg11] -> (index) {
%6 = affine.apply #map1()[%arg10]
air.execute_terminator %6 : index
}
%4 = air.channel.put async [%async_token_0, %async_token] @channel_4[] (%arg8[%c0, %c0, %results, %results_1] [%c4, %c1, %c32, %c32] [%c16384, %c32, %c512, %c1]) {id = 3 : i32} : (memref<1024x512xi32>)
%async_token_2, %results_3 = air.execute [%arg11] -> (index) {
%6 = affine.apply #map1()[%arg10]
air.execute_terminator %6 : index
}
%async_token_4, %results_5 = air.execute -> (index) {
%6 = affine.apply #map()[%arg5]
air.execute_terminator %6 : index
}
%5 = air.channel.put async [%4, %async_token_2, %async_token_4] @channel_4[] (%arg9[%c0, %c0, %results_3, %results_5] [%c1, %c4, %c32, %c32] [%c32768, %c32, %c1024, %c1]) {id = 4 : i32} : (memref<512x1024xi32>)
scf.yield %4 : !air.async.token
}
%3 = air.segment @matmul_elementwise_i32_dispatch_0_matmul_1024x1024x512_i32_0 async attributes {id = 2 : i32} {
%c16_0 = arith.constant 16 : index
%c256 = arith.constant 256 : index
%c8_1 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c4096 = arith.constant 4096 : index
%c1024_2 = arith.constant 1024 : index
%c32_3 = arith.constant 32 : index
%c0_4 = arith.constant 0 : index
%c4_5 = arith.constant 4 : index
%c1_6 = arith.constant 1 : index
%async_token_17, %results_18 = air.execute -> (memref<1x1x8x4x8x4xi32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x4x8x4xi32, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x4x8x4xi32, 2 : i32>
}
%async_token_19, %results_20 = air.execute -> (memref<1x1x4x8x4x8xi32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x4x8x4x8xi32, 2 : i32>
air.execute_terminator %alloc : memref<1x1x4x8x4x8xi32, 2 : i32>
}
%async_token_21, %results_22 = air.execute -> (memref<1x4x32x32xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<1x4x32x32xi32, 1 : i32>
air.execute_terminator %alloc : memref<1x4x32x32xi32, 1 : i32>
}
%async_token_23, %results_24 = air.execute -> (memref<4x1x32x32xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<4x1x32x32xi32, 1 : i32>
air.execute_terminator %alloc : memref<4x1x32x32xi32, 1 : i32>
}
%4 = air.herd @herd_0 async [%async_token_17, %async_token_19, %async_token_21, %async_token_23] tile (%arg10, %arg11) in (%arg12=%c4_5, %arg13=%c4_5) args(%arg14=%results_20, %arg15=%results_18) : memref<1x1x4x8x4x8xi32, 2 : i32>, memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 3 : i32} {
%7 = air.channel.get async @channel_6[%arg10, %arg11] (%arg14[] [] []) {id = 13 : i32} : (memref<1x1x4x8x4x8xi32, 2 : i32>)
%8 = air.channel.get async @channel_6[%arg10, %arg11] (%arg15[] [] []) {id = 14 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
}
%5 = scf.for %arg10 = %c0_4 to %c16_0 step %c1_6 iter_args(%arg11 = %async_token_23) -> (!air.async.token) {
%7 = air.channel.get async [%arg11] @channel_4[] (%results_24[] [] []) {id = 15 : i32} : (memref<4x1x32x32xi32, 1 : i32>)
%8 = air.channel.get async [%7] @channel_4[] (%results_22[] [] []) {id = 16 : i32} : (memref<1x4x32x32xi32, 1 : i32>)
scf.yield %7 : !air.async.token
}
%6 = scf.for %arg10 = %c0_4 to %c16_0 step %c1_6 iter_args(%arg11 = %async_token_23) -> (!air.async.token) {
%7 = scf.parallel (%arg12, %arg13) = (%c0_4, %c0_4) to (%c4_5, %c4_5) step (%c1_6, %c1_6) init (%arg11) -> !air.async.token {
%8 = air.channel.put async [%arg11] @channel_6[%arg12, %arg13] (%results_24[%arg12, %c0_4, %c0_4, %c0_4, %c0_4, %c0_4] [%c1_6, %c1_6, %c4_5, %c8_1, %c4_5, %c8_1] [%c1024_2, %c1024_2, %c8_1, %c128, %c32_3, %c1_6]) {id = 17 : i32} : (memref<4x1x32x32xi32, 1 : i32>)
%9 = air.channel.put async [%8] @channel_6[%arg12, %arg13] (%results_22[%c0_4, %arg13, %c0_4, %c0_4, %c0_4, %c0_4] [%c1_6, %c1_6, %c8_1, %c4_5, %c8_1, %c4_5] [%c4096, %c1024_2, %c4_5, %c256, %c32_3, %c1_6]) {id = 18 : i32} : (memref<1x4x32x32xi32, 1 : i32>)
scf.reduce(%8 : !air.async.token) {
^bb0(%arg14: !air.async.token, %arg15: !air.async.token):
%10 = air.wait_all async [%arg14, %arg15]
scf.reduce.return %10 : !air.async.token
}
}
scf.yield %7 : !air.async.token
}
%async_token_25 = air.execute [%6] {
memref.dealloc %results_24 : memref<4x1x32x32xi32, 1 : i32>
}
%async_token_26 = air.execute [%6] {
memref.dealloc %results_22 : memref<1x4x32x32xi32, 1 : i32>
}
%async_token_27 = air.execute [%6] {
memref.dealloc %results_20 : memref<1x1x4x8x4x8xi32, 2 : i32>
}
%async_token_28 = air.execute [%6] {
memref.dealloc %results_18 : memref<1x1x8x4x8x4xi32, 2 : i32>
}
}
}
return
}
}

0 comments on commit 8f8a1d5

Please sign in to comment.