Skip to content

Commit

Permalink
Add logical objectFifo hoisting pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls committed Aug 15, 2024
1 parent 8b9c73e commit 907b52a
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 5 deletions.
7 changes: 3 additions & 4 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
#ifndef IREE_COMPILER_AMDAIE_OPS_H_
#define IREE_COMPILER_AMDAIE_OPS_H_

#include "iree-amd-aie/IR/AMDAIEAttrs.h"
#include "iree-amd-aie/IR/AMDAIEDmaOpInterface.h"
#include "iree-amd-aie/IR/AMDAIETypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

#include "iree-amd-aie/IR/AMDAIEAttrs.h"
#include "iree-amd-aie/IR/AMDAIEDmaOpInterface.h"
#include "iree-amd-aie/IR/AMDAIETypes.h"

// clang-format off
#include "iree-amd-aie/IR/AMDAIEAttrs.h"
#include "iree-amd-aie/IR/AMDAIEDialect.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

#define DEBUG_TYPE "iree-amdaie-hoist-logical-objectfifo"

namespace mlir::iree_compiler::AMDAIE {

/// Hoist logical objectFifo operations until one of the operands is located
/// within the same scope.
LogicalResult hoistLogicalObjFifoOp(RewriterBase &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp op) {
Operation *parentOp = op;
while (parentOp) {
Operation *newParentOp = parentOp->getParentOp();
if (llvm::any_of(op->getOperands(), [&](Value operand) {
return operand.getDefiningOp() &&
newParentOp->isProperAncestor(operand.getDefiningOp());
})) {
break;
}
if (isa<AMDAIE::WorkgroupOp, AMDAIE::ControlCodeOp, func::FuncOp>(
newParentOp)) {
break;
}
parentOp = newParentOp;
}
if (parentOp && parentOp != op) rewriter.moveOpBefore(op, parentOp);
return failure();
}

namespace {
struct AMDAIEHoistLogicalObjFifoPass
: public impl::AMDAIEHoistLogicalObjFifoBase<
AMDAIEHoistLogicalObjFifoPass> {
void runOnOperation() override;
};

void AMDAIEHoistLogicalObjFifoPass::runOnOperation() {
Operation *parentOp = getOperation();
IRRewriter rewriter(parentOp->getContext());
parentOp->walk([&](AMDAIE::LogicalObjectFifoFromMemrefOp op) {
(void)hoistLogicalObjFifoOp(rewriter, op);
});
}

} // namespace

std::unique_ptr<Pass> createAMDAIEHoistLogicalObjFifoPass() {
return std::make_unique<AMDAIEHoistLogicalObjFifoPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ iree_cc_library(
"AMDAIEFuseFillIntoForall.cpp"
"AMDAIEFusePackIntoLoop.cpp"
"AMDAIEHoistForAffineApply.cpp"
"AMDAIEHoistLogicalObjFifo.cpp"
"AMDAIEInsertCores.cpp"
"AMDAIEInsertLoopsForVectorization.cpp"
"AMDAIELinkExecutables.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
#define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
#define GEN_PASS_DEF_AMDAIEHOISTFORLOOPAFFINEAPPLY
#define GEN_PASS_DEF_AMDAIEHOISTLOGICALOBJFIFO
#define GEN_PASS_DEF_AMDAIEINSERTAIEWORKGROUP
#define GEN_PASS_DEF_AMDAIEINSERTCORES
#define GEN_PASS_DEF_AMDAIEINSERTLOOPSFORVECTORIZATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager) {
passManager.addNestedPass<func::FuncOp>(createAMDAIECreateAIEWorkgroupPass());
passManager.addPass(createCSEPass());

passManager.addPass(createAMDAIEHoistLogicalObjFifoPass());
passManager.addPass(createAMDAIECanonicalizeDoublyStridedOpPass());
passManager.addPass(createAMDAIEFlattenLogicalObjectFifoPass());
passManager.addPass(createAMDAIEAssignLogicalObjectFifoDepthPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ std::unique_ptr<Pass> createAMDAIEFuseFillIntoForallPass();
/// Hoist an affine.apply op on a scf.for op's induction variable.
std::unique_ptr<Pass> createAMDAIEHoistForLoopAffineApplyPass();

/// Create a pass to hoist logical objectFifo operations to the scope of its
/// operands.
std::unique_ptr<Pass> createAMDAIEHoistLogicalObjFifoPass();

/// Create a pass to transform linalg.generics into a form which benefits later
/// vectorization passes (to vector and aievec dialects).
std::unique_ptr<Pass> createAMDAIEInsertLoopsForVectorizationPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def AMDAIEHoistForLoopAffineApply : Pass<"iree-amdaie-hoist-for-affine-apply"> {
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEHoistForLoopAffineApplyPass()";
}

def AMDAIEHoistLogicalObjFifo : Pass<"iree-amdaie-hoist-logical-objectfifo"> {
let summary = "Hoist logical objectFifo operations to the scope of its operands.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEHoistLogicalObjFifoPass()";
}

def AMDAIEInsertCores :
Pass<"iree-amdaie-insert-cores", "ModuleOp"> {
let summary = "Insert `amdaie.core` operations inside the innermost "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ iree_lit_test_suite(
"fuse_fill_into_forall.mlir"
"fuse_pack_into_loop.mlir"
"hoist_for_affine_apply.mlir"
"hoist_logical_obj_fifo.mlir"
"insert_cores.mlir"
"insert_loops_for_vectorization.mlir"
"localize_logical_objectfifo.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-hoist-logical-objectfifo)" %s | FileCheck %s

// CHECK-LABEL: @func_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
// CHECK: scf.forall
// CHECK-NOT: amdaie.logicalobjectfifo.from_memref
module {
func.func @func_hoist(%arg0: memref<32x64xi32>) {
%c0 = arith.constant 0 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
scf.forall (%arg1, %arg2) in (1, 2) {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
return
}
}


// -----

// CHECK-LABEL: @func_no_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: scf.forall
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
module {
func.func @func_no_hoist(%arg0: memref<32x64xi32>) {
%c0 = arith.constant 0 : index
scf.forall (%arg1, %arg2) in (1, 2) {
%tile_0_0 = amdaie.tile(%c0, %c0)
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
return
}
}

// -----

// CHECK-LABEL: @workgroup_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: amdaie.workgroup
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
// CHECK: scf.forall
// CHECK-NOT: amdaie.logicalobjectfifo.from_memref
func.func @workgroup_hoist(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
%c0 = arith.constant 0 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
scf.forall (%arg1, %arg2) in (1, 2) {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
amdaie.controlcode {
amdaie.end
}
}
return
}

// -----

// CHECK-LABEL: @workgroup_no_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: amdaie.workgroup
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: scf.forall
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
func.func @workgroup_no_hoist(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
%c0 = arith.constant 0 : index
scf.forall (%arg1, %arg2) in (1, 2) {
%tile_0_0 = amdaie.tile(%c0, %c0)
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
amdaie.controlcode {
amdaie.end
}
}
return
}
// -----

// CHECK-LABEL: @workgroup_no_hoist_outside
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.workgroup
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
func.func @workgroup_no_hoist_outside(%arg0: memref<32x64xi32>) {
%c0 = arith.constant 0 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
amdaie.workgroup {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
amdaie.controlcode {
amdaie.end
}
}
return
}

// -----

// CHECK-LABEL: @controlcode_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: amdaie.controlcode
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
// CHECK: scf.forall
// CHECK: scf.forall
// CHECK-NOT: amdaie.logicalobjectfifo.from_memref
func.func @controlcode_hoist(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
amdaie.controlcode {
%c0 = arith.constant 0 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
scf.forall (%arg1, %arg2) in (1, 2) {
scf.forall (%arg3, %arg4) in (1, 2) {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
}
amdaie.end
}
}
return
}

// -----

// CHECK-LABEL: @controlcode_partial_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: amdaie.controlcode
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: scf.forall
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
// CHECK: scf.forall
// CHECK-NOT: amdaie.logicalobjectfifo.from_memref
func.func @controlcode_partial_hoist(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
amdaie.controlcode {
%c0 = arith.constant 0 : index
scf.forall (%arg1, %arg2) in (1, 2) {
%tile_0_0 = amdaie.tile(%c0, %c0)
scf.forall (%arg3, %arg4) in (1, 2) {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
}
amdaie.end
}
}
return
}

// -----

// CHECK-LABEL: @controlcode_no_hoist
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: amdaie.controlcode
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: scf.forall
// CHECK: scf.forall
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
func.func @controlcode_no_hoist(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
amdaie.controlcode {
%c0 = arith.constant 0 : index
scf.forall (%arg1, %arg2) in (1, 2) {
scf.forall (%arg3, %arg4) in (1, 2) {
%tile_0_0 = amdaie.tile(%c0, %c0)
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
}
}
amdaie.end
}
}
return
}

// -----

// CHECK-LABEL: @controlcode_no_hoist_outside
// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xi32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: amdaie.controlcode
// CHECK: amdaie.logicalobjectfifo.from_memref %[[ARG0]], {%[[TILE_0_0]]}
func.func @controlcode_no_hoist_outside(%arg0: memref<32x64xi32>) {
amdaie.workgroup {
%c0 = arith.constant 0 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
amdaie.controlcode {
%obj0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_0} : memref<32x64xi32> -> !amdaie.logicalobjectfifo<memref<32x64xi32>>
amdaie.end
}
}
return
}
2 changes: 1 addition & 1 deletion tests/samples/matmul_peeled_objectfifo.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(fold-memref-alias-ops,iree-amdaie-pack-to-dma,air-copy-to-dma,iree-amdaie-air-dma-to-amdaie-dma,iree-amdaie-insert-cores,cse,iree-amdaie-localize-logicalobjectfifo,iree-amdaie-distribute-cores-and-objectfifos,cse,canonicalize,iree-amdaie-dma-to-circular-dma,func.func(iree-amdaie-create-aie-workgroup),cse,iree-amdaie-canonicalize-doubly-strided-op,iree-amdaie-flatten-logicalobjectfifo,iree-amdaie-access-to-acquire-release,cse,canonicalize,iree-amdaie-dma-loop-subsumption,cse,canonicalize,iree-amdaie-assign-npu-dma-bd-ids,iree-amdaie-controlcode-loop-unroll,cse,canonicalize,iree-amdaie-create-logical-objectfifo-link,iree-amdaie-canonicalize-doubly-strided-op,iree-amdaie-lower-to-aie,canonicalize)" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(fold-memref-alias-ops,iree-amdaie-pack-to-dma,air-copy-to-dma,iree-amdaie-air-dma-to-amdaie-dma,iree-amdaie-insert-cores,cse,iree-amdaie-localize-logicalobjectfifo,iree-amdaie-distribute-cores-and-objectfifos,cse,canonicalize,iree-amdaie-dma-to-circular-dma,func.func(iree-amdaie-create-aie-workgroup),cse,iree-amdaie-hoist-logical-objectfifo,iree-amdaie-canonicalize-doubly-strided-op,iree-amdaie-flatten-logicalobjectfifo,iree-amdaie-access-to-acquire-release,cse,canonicalize,iree-amdaie-dma-loop-subsumption,cse,canonicalize,iree-amdaie-assign-npu-dma-bd-ids,iree-amdaie-controlcode-loop-unroll,cse,canonicalize,iree-amdaie-create-logical-objectfifo-link,iree-amdaie-canonicalize-doubly-strided-op,iree-amdaie-lower-to-aie,canonicalize)" --split-input-file %s | FileCheck %s

// CHECK: aie.device(npu1_4col)
// CHECK-DAG: %[[TILE_0_2:.+]] = aie.tile(0, 2)
Expand Down

0 comments on commit 907b52a

Please sign in to comment.