Skip to content

Commit

Permalink
Add depth=-1 mode to air-par-to-herd/segment/launch (#866)
Browse files Browse the repository at this point in the history
* Add depth=-1 mode to air-par-to-herd/segment/launch representing converting the innermost loop body

* Enable depth=-1 for forall
  • Loading branch information
erwei-xilinx authored Jan 20, 2025
1 parent 8ead7cb commit 6e686c7
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 9 deletions.
18 changes: 12 additions & 6 deletions mlir/include/air/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ def ParallelToHerd : Pass<"air-par-to-herd", "ModuleOp"> {
}];
let options = [
Option<"clAssignDepth", "depth", "int",
/*default=*/"-1",
"Given a nest of parallel for loops, which depth to map to air.herd">,
/*default=*/"-2",
"Given a nest of parallel for loops, which depth to map to air.herd. "
"-1 means converting the innermost parallel loop; any other negative "
"value means converting all parallel loops">,
Option<"clFirstDim", "first-dim", "int",
/*default=*/"0",
"Which herd dimension to map to first. Can be zero or one. If set to "
Expand All @@ -49,8 +51,10 @@ def ParallelToLaunch : Pass<"air-par-to-launch", "ModuleOp"> {
}];
let options = [
Option<"clAssignDepth", "depth", "int",
/*default=*/"-1",
"Given a nest of parallel for loops, which depth to map to air.launch">,
/*default=*/"-2",
"Given a nest of parallel for loops, which depth to map to air.launch"
"-1 means converting the innermost parallel loop; any other negative "
"value means converting all parallel loops">,
Option<"clHasSegment", "has-air-segment", "bool", /*default=*/"false",
"Whether to create an air.segment op in generated air.launch "
"regions">,
Expand All @@ -68,8 +72,10 @@ def ParallelToSegment : Pass<"air-par-to-segment", "ModuleOp"> {
}];
let options = [
Option<"clAssignDepth", "depth", "int",
/*default=*/"-1",
"Given a nest of parallel for loops, which depth to map to air.segment">,
/*default=*/"-2",
"Given a nest of parallel for loops, which depth to map to air.segment"
"-1 means converting the innermost parallel loop; any other negative "
"value means converting all parallel loops">,
];
}

Expand Down
48 changes: 45 additions & 3 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,11 +1165,25 @@ struct ParallelToHerdPass
if (llvm::any_of(hierOps,
[op](Operation *h) { return op->isProperAncestor(h); }))
return;
// Depth = -1 means converting the innermost parallel ops
if (clAssignDepth == -1) {
SmallVector<Operation *> parOpsInOp;
op->walk([&parOpsInOp](Operation *o) {
if (isa<scf::ForallOp, scf::ParallelOp, affine::AffineParallelOp>(o))
parOpsInOp.push_back(o);
});
if (parOpsInOp.size() > 1)
return;
filteredOps.insert(op);
return;
}
// Assigning depth to other negative values means converting all
// parallel ops
if (clAssignDepth < 0) {
filteredOps.insert(op);
return;
}
// the number of nested scf.parallel above this one
// the number of nested parallel above this one
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
Expand Down Expand Up @@ -1253,11 +1267,25 @@ struct ParallelToLaunchPass
return op->isProperAncestor(l);
}))
return;
// Depth = -1 means converting the innermost parallel ops
if (clAssignDepth == -1) {
SmallVector<Operation *> parOpsInOp;
op->walk([&parOpsInOp](Operation *o) {
if (isa<scf::ParallelOp>(o))
parOpsInOp.push_back(o);
});
if (parOpsInOp.size() > 1)
return;
filteredOps.insert(op);
return;
}
// Assigning depth to other negative values means converting all
// parallel ops
if (clAssignDepth < 0) {
filteredOps.insert(op);
return;
}
// the number of nested scf.parallel above this one
// the number of nested parallel above this one
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
Expand Down Expand Up @@ -1342,11 +1370,25 @@ struct ParallelToSegmentPass
return op->isProperAncestor(s);
}))
return;
// Depth = -1 means converting the innermost parallel ops
if (clAssignDepth == -1) {
SmallVector<Operation *> parOpsInOp;
op->walk([&parOpsInOp](Operation *o) {
if (isa<scf::ParallelOp>(o))
parOpsInOp.push_back(o);
});
if (parOpsInOp.size() > 1)
return;
filteredOps.insert(op);
return;
}
// Assigning depth to other negative values means converting all
// parallel ops
if (clAssignDepth < 0) {
filteredOps.insert(op);
return;
}
// the number of nested scf.parallel above this one
// the number of nested parallel above this one
int parallel_depth = 0;
Operation *par = op;
while ((par = par->getParentOp()))
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
//===----------------------------------------------------------------------===//

// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd %s | FileCheck %s
// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=-1" %s | FileCheck %s --check-prefix=DEPTHM1
// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=0" %s | FileCheck %s --check-prefix=DEPTH0
// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=1" %s | FileCheck %s --check-prefix=DEPTH1

// CHECK-LABEL: func.func @scf0() {
// CHECK: air.herd @herd_0 tile (%{{.*}}, %{{.*}}) in (%{{.*}}=%c2{{.*}}, %{{.*}}=%c2{{.*}})
Expand Down Expand Up @@ -98,6 +101,27 @@ func.func @scf4() {
// CHECK: }
// CHECK: }
// CHECK: }
// DEPTHM1-LABEL: func.func @scf5() {
// DEPTHM1: scf.forall {{.*}} {
// DEPTHM1: scf.forall {{.*}} {
// DEPTHM1: air.herd @herd_{{.*}} {
// DEPTHM1: }
// DEPTHM1: }
// DEPTHM1: }
// DEPTH0-LABEL: func.func @scf5() {
// DEPTH0: air.herd @herd_{{.*}} {
// DEPTH0: scf.forall {{.*}} {
// DEPTH0: scf.forall {{.*}} {
// DEPTH0: }
// DEPTH0: }
// DEPTH0: }
// DEPTH1-LABEL: func.func @scf5() {
// DEPTH1: scf.forall {{.*}} {
// DEPTH1: air.herd @herd_{{.*}} {
// DEPTH1: scf.forall {{.*}} {
// DEPTH1: }
// DEPTH1: }
// DEPTH1: }
func.func @scf5() {
%src = memref.alloc() : memref<4x4x4xi32, 2 : i32>
%dst = memref.alloc() : memref<4x4x4xi32, 2 : i32>
Expand Down
76 changes: 76 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_parallel_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd %s | FileCheck %s
// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=-1" %s | FileCheck %s --check-prefix=DEPTHM1
// RUN: air-opt -split-input-file -verify-diagnostics -air-par-to-herd="depth=0" %s | FileCheck %s --check-prefix=DEPTH0

// CHECK-LABEL: func.func @scf0() {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
Expand Down Expand Up @@ -201,6 +203,34 @@ module {
// CHECK: return
// CHECK: }
// CHECK: }
// DEPTHM1-LABEL: @shared_herd_name
// DEPTHM1: scf.parallel {{.*}} {
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: scf.reduce
// DEPTHM1: }
// DEPTHM1: return
// DEPTHM1: }
// DEPTHM1: }
// DEPTH0-LABEL: @shared_herd_name
// DEPTH0: air.herd @herd_0
// DEPTH0: scf.parallel {{.*}}
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: scf.parallel {{.*}}
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: scf.parallel {{.*}}
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: }
// DEPTH0: return
// DEPTH0: }
// DEPTH0: }
module {
func.func @shared_herd_name(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) {
%c32 = arith.constant 32 : index
Expand Down Expand Up @@ -248,6 +278,29 @@ module {
// CHECK: return
// CHECK: }
// CHECK: }
// DEPTHM1-LABEL: @unique_herd_name
// DEPTHM1: scf.parallel {{.*}} {
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: air.herd @herd_1
// DEPTHM1: }
// DEPTHM1: scf.reduce
// DEPTHM1: }
// DEPTHM1: return
// DEPTHM1: }
// DEPTHM1: }
// DEPTH0-LABEL: @unique_herd_name
// DEPTH0: air.herd @herd_0
// DEPTH0: scf.parallel {{.*}} {
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: scf.parallel {{.*}} {
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: }
// DEPTH0: return
// DEPTH0: }
// DEPTH0: }
module {
func.func @unique_herd_name(%arg0: memref<512x1024xbf16>, %arg1: memref<1024x512xbf16>, %arg2: memref<512x512xbf16>) {
%c32 = arith.constant 32 : index
Expand Down Expand Up @@ -303,6 +356,29 @@ module {
// CHECK: return
// CHECK: }
// CHECK: }
// DEPTHM1-LABEL: @l2_to_l1_dma_infer_herd
// DEPTHM1: scf.parallel {{.*}} {
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: air.herd @herd_0
// DEPTHM1: }
// DEPTHM1: scf.reduce
// DEPTHM1: }
// DEPTHM1: return
// DEPTHM1: }
// DEPTHM1: }
// DEPTH0-LABEL: @l2_to_l1_dma_infer_herd
// DEPTH0: air.herd @herd_0
// DEPTH0: scf.parallel {{.*}} {
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: scf.parallel {{.*}} {
// DEPTH0: scf.reduce
// DEPTH0: }
// DEPTH0: }
// DEPTH0: return
// DEPTH0: }
// DEPTH0: }
module {
func.func @l2_to_l1_dma_infer_herd() {
%c32 = arith.constant 32 : index
Expand Down

0 comments on commit 6e686c7

Please sign in to comment.