AIRWrapFuncWithParallel: add a new pass that wraps a funcOp body in an scf.parallel loop with specified loop bounds #898

Merged 2 commits on Feb 14, 2025

4 changes: 4 additions & 0 deletions mlir/include/air/Conversion/ConvertToAIRPass.h
@@ -32,6 +32,10 @@ createParallelToSegmentPass(const ParallelToSegmentOptions &options);
std::unique_ptr<mlir::Pass> createCopyToDmaPass();
std::unique_ptr<mlir::Pass> createInsertEmptyLaunchOverHerdPass();

std::unique_ptr<Pass> createAIRWrapFuncWithParallelPass();
std::unique_ptr<mlir::Pass>
createAIRWrapFuncWithParallelPass(AIRWrapFuncWithParallelPassOptions options);

} // namespace air
} // namespace xilinx

1 change: 1 addition & 0 deletions mlir/include/air/Conversion/PassDetail.h
@@ -32,6 +32,7 @@ using namespace mlir;
#define GEN_PASS_DEF_PARALLELTOHERD
#define GEN_PASS_DEF_PARALLELTOLAUNCH
#define GEN_PASS_DEF_PARALLELTOSEGMENT
#define GEN_PASS_DEF_AIRWRAPFUNCWITHPARALLELPASS
#include "air/Conversion/Passes.h.inc"

} // namespace air
17 changes: 17 additions & 0 deletions mlir/include/air/Conversion/Passes.td
@@ -430,4 +430,21 @@ def InsertEmptyLaunchOverHerd : Pass<"air-insert-launch-and-segment-around-herd"
}];
}

def AIRWrapFuncWithParallelPass : Pass<"air-wrap-func-with-parallel", "func::FuncOp"> {
let summary = "Wraps the body of a given func.func operation inside an scf.parallel loop.";
let constructor = "xilinx::air::createAIRWrapFuncWithParallelPass()";
let description = [{
Wraps the body of a given func.func operation inside an scf.parallel loop.
The pass assumes that:
    (1) The function arguments consist of: M memref arguments, N loop upper bounds, and N loop induction-variable indices.
    (2) The scf.parallel loop bounds are taken from the `loop-bounds` pass option, and the loop's induction variables replace the N induction-variable arguments inside the loop body.
    (3) The scf.parallel loop is inserted at the beginning of the function, wrapping all existing operations.
}];
let options = [
ListOption<"clLoopBounds", "loop-bounds", "unsigned",
"Specify upper bounds for scf.parallel loops",
"llvm::cl::ZeroOrMore">,
];
}

#endif
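To make the argument convention and the `loop-bounds` option concrete, here is a minimal before/after sketch, assuming a hypothetical function with M = 2 memref arguments and N = 2 loop dimensions; the function and value names are illustrative, and the actual rewrite is implemented by the pattern added in ConvertToAIRPass.cpp below.

```mlir
// Before: two memrefs, two upper-bound arguments (%ub0, %ub1), and two
// induction-variable arguments (%iv0, %iv1) used by the body.
func.func @example(%a: memref<*xf32>, %b: memref<*xf32>,
                   %ub0: i32, %ub1: i32, %iv0: i32, %iv1: i32) {
  // ... body using %iv0 and %iv1 ...
  return
}

// After air-opt -air-wrap-func-with-parallel='loop-bounds=2,2' (sketch): the
// body is wrapped in an scf.parallel loop whose bounds come from the pass
// option; the index-typed loop variables are cast to i32 and substituted for
// %iv0 / %iv1 inside the loop.
func.func @example(%a: memref<*xf32>, %b: memref<*xf32>,
                   %ub0: i32, %ub1: i32, %iv0: i32, %iv1: i32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %i_i32 = arith.index_cast %i : index to i32
    %j_i32 = arith.index_cast %j : index to i32
    // ... body cloned here, with %iv0 / %iv1 replaced by %i_i32 / %j_i32 ...
    scf.reduce
  }
  return
}
```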
217 changes: 217 additions & 0 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
@@ -1465,6 +1465,215 @@ struct InsertEmptyLaunchOverHerdPass
}
};



// Identifies arith operations where all operands are either constants, or
// produced by IndexCastOp casting from IndexType. If detected, canonicalize
// IndexCast ops by changing the arith op's input/output types to IndexType.
template <typename T>
LogicalResult canonicalizeArithBinaryOpToIndexType(T arithOp,
RewriterBase &rewriter) {
Value lhs = arithOp.getLhs();
Value rhs = arithOp.getRhs();

SmallVector<Value> inVals = {lhs, rhs};
if (llvm::all_of(inVals, [](Value v) { return isa<IndexType>(v.getType()); }))
return failure();
if (llvm::any_of(inVals, [](Value v) {
if (!v.getDefiningOp())
return true;
if (getConstantIntValue(v))
return false;
else if (auto castOp = dyn_cast_if_present<arith::IndexCastOp>(
v.getDefiningOp())) {
if (llvm::all_of(castOp->getOperands(), [](Value oper) {
return isa<IndexType>(oper.getType());
}))
return false;
else
return true;
}
return true;
}))
return failure();

auto loc = arithOp.getLoc();
if (!isa<IndexType>(lhs.getType())) {
lhs =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), lhs);
}
if (!isa<IndexType>(rhs.getType())) {
rhs =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), rhs);
}
auto newArithOp = rewriter.create<T>(loc, rewriter.getIndexType(), lhs, rhs);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
arithOp, arithOp.getResult().getType(), newArithOp);

return success();
}

struct CanonicalizeArithAddIOpToIndexTypePattern
: public OpRewritePattern<arith::AddIOp> {
using OpRewritePattern<arith::AddIOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::AddIOp arithOp,
PatternRewriter &rewriter) const override {
return canonicalizeArithBinaryOpToIndexType<arith::AddIOp>(arithOp,
rewriter);
}
};
struct CanonicalizeArithMulIOpToIndexTypePattern
: public OpRewritePattern<arith::MulIOp> {
using OpRewritePattern<arith::MulIOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::MulIOp arithOp,
PatternRewriter &rewriter) const override {
return canonicalizeArithBinaryOpToIndexType<arith::MulIOp>(arithOp,
rewriter);
}
};
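
To make the effect of these two patterns concrete, here is a small before/after sketch (value names are illustrative): an i32 arith op whose operands are either constants or casts of index values is re-created on index types, with a trailing cast preserving the original result type.

```mlir
// Before: an i32 addition whose lhs is an index_cast from an index value and
// whose rhs is an i32 constant.
%lhs_i32 = arith.index_cast %i : index to i32
%sum_i32 = arith.addi %lhs_i32, %c32_i32 : i32

// After CanonicalizeArithAddIOpToIndexTypePattern (sketch): the addition is
// performed on index types; the final cast reproduces the original i32 result
// for existing users, and the back-to-back casts become foldable.
%lhs_idx = arith.index_cast %lhs_i32 : i32 to index
%rhs_idx = arith.index_cast %c32_i32 : i32 to index
%sum_idx = arith.addi %lhs_idx, %rhs_idx : index
%sum_i32 = arith.index_cast %sum_idx : index to i32
```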

// Wraps the body of a given func.func operation inside an scf.parallel loop.
// The pass assumes that:
// (1) The function arguments consist of: M memref arguments, N loop upper
//     bounds, and N loop induction-variable indices.
// (2) The scf.parallel loop bounds are taken from the pass's loop-bounds
//     option, and the loop's induction variables replace the N
//     induction-variable arguments inside the loop body.
// (3) The scf.parallel loop is inserted at the beginning of the function,
//     wrapping all existing operations.

struct WrapFuncWithParallelPattern : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern<func::FuncOp>::OpRewritePattern;

WrapFuncWithParallelPattern(MLIRContext *context,
SmallVector<int64_t> &bounds)
: OpRewritePattern(context), loopBounds(bounds) {}

LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
if (funcOp.isExternal())
return failure(); // Ignore external functions

if (loopBounds.empty()) {
funcOp.emitError("Pass option 'loop-bounds' must be specified.");
return failure();
}

unsigned N = loopBounds.size(); // Number of loop dimensions

// Get function arguments
auto args = funcOp.getArguments();
unsigned numArgs = args.size();

if (numArgs < 2) {
funcOp.emitError(
"Expected at least 2 arguments: memrefs and loop bounds.");
return failure();
}

// Determine M (memrefs count)
unsigned M = 0;
for (Type argType : funcOp.getFunctionType().getInputs()) {
if (isa<UnrankedMemRefType, MemRefType>(argType))
M++;
else
break;
}
if (M + N * 2 > numArgs) {
funcOp.emitError("Expected func op arguments contain at least M memrefs "
"and N x 2 loop bounds.");
return failure();
}

// Extract indices
ValueRange inductionVars = args.slice(M + N, N);

if (llvm::all_of(inductionVars, [](Value iv) { return iv.use_empty(); }))
return failure();

// Store original function body operations
SmallVector<Operation *> originalFuncBodyOps;
for (auto &op : funcOp.getBody().front().without_terminator())
originalFuncBodyOps.push_back(&op);

// Create scf.parallel loop
Location loc = funcOp.getLoc();
rewriter.setInsertionPointToStart(&funcOp.getBody().front());
SmallVector<Value, 4> lowerBounds(
N, rewriter.create<arith::ConstantIndexOp>(loc, 0));
SmallVector<Value, 4> steps(
N, rewriter.create<arith::ConstantIndexOp>(loc, 1));
SmallVector<Value, 4> upperBoundsVals;
for (int64_t bound : loopBounds) {
upperBoundsVals.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, bound));
}

auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lowerBounds,
upperBoundsVals, steps);

// Redirect arguments properly inside the loop
Block &loopBlock = parallelOp.getRegion().front();
rewriter.setInsertionPointToStart(&loopBlock);
IRMapping remap;
for (unsigned i = 0; i < N; i++) {
Value loopBlockArg = loopBlock.getArgument(i);
if (inductionVars[i].getType() != loopBlockArg.getType())
loopBlockArg = rewriter.create<arith::IndexCastOp>(
loc, inductionVars[i].getType(), loopBlockArg);
remap.map(inductionVars[i], loopBlockArg);
}

// Move function body into the loop
for (auto op : originalFuncBodyOps) {
rewriter.clone(*op, remap);
}

// Erase original function body ops
for (auto o : llvm::reverse(originalFuncBodyOps))
rewriter.eraseOp(o);

return success();
}

private:
SmallVector<int64_t> &loopBounds; // External loop bounds
};

class AIRWrapFuncWithParallelPass
: public air::impl::AIRWrapFuncWithParallelPassBase<
AIRWrapFuncWithParallelPass> {

public:
AIRWrapFuncWithParallelPass() = default;
AIRWrapFuncWithParallelPass(const AIRWrapFuncWithParallelPass &pass){};
AIRWrapFuncWithParallelPass(
const ::xilinx::air::AIRWrapFuncWithParallelPassOptions &options)
: AIRWrapFuncWithParallelPassBase(options) {}

void runOnOperation() override;

private:
};

void AIRWrapFuncWithParallelPass::runOnOperation() {
func::FuncOp funcOp = getOperation();
MLIRContext *context = &getContext();

RewritePatternSet wrapParPatterns(context);
SmallVector<int64_t> loopBoundsVec;
for (auto i : clLoopBounds)
loopBoundsVec.push_back(i);
wrapParPatterns.add<WrapFuncWithParallelPattern>(context, loopBoundsVec);
(void)applyOpPatternsGreedily(SmallVector<Operation *>{funcOp.getOperation()},
std::move(wrapParPatterns));

RewritePatternSet patterns(context);
patterns.add<CanonicalizeArithAddIOpToIndexTypePattern,
CanonicalizeArithMulIOpToIndexTypePattern>(context);
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
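
The two greedy rewrites are meant to compose: WrapFuncWithParallelPattern introduces index-to-i32 casts for the induction variables, and the canonicalization patterns then migrate the dependent i32 arithmetic back onto index types. A rough sketch of this interplay on a multiply like the one in the test below (value names are illustrative):

```mlir
// After WrapFuncWithParallelPattern: the i32 induction-variable argument is
// replaced inside the loop by a cast of the index-typed loop variable %i.
%i_i32 = arith.index_cast %i : index to i32
%off   = arith.muli %i_i32, %c32_i32 : i32

// After the canonicalization patterns: the multiply is re-created on index
// types; the remaining cast round-trips can be folded by the greedy driver or
// later canonicalization.
%i_rt   = arith.index_cast %i_i32 : i32 to index
%c32    = arith.index_cast %c32_i32 : i32 to index
%off_ix = arith.muli %i_rt, %c32 : index
%off    = arith.index_cast %off_ix : index to i32
```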

} // namespace

//===----------------------------------------------------------------------===//
@@ -1591,5 +1800,13 @@ std::unique_ptr<mlir::Pass> createInsertEmptyLaunchOverHerdPass() {
return std::make_unique<InsertEmptyLaunchOverHerdPass>();
}

std::unique_ptr<Pass> createAIRWrapFuncWithParallelPass() {
return std::make_unique<AIRWrapFuncWithParallelPass>();
}
std::unique_ptr<Pass>
createAIRWrapFuncWithParallelPass(AIRWrapFuncWithParallelPassOptions options) {
return std::make_unique<AIRWrapFuncWithParallelPass>(options);
}

} // namespace air
} // namespace xilinx
38 changes: 38 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/wrap_func_with_parallel.mlir
@@ -0,0 +1,38 @@
//===- wrap_func_with_parallel.mlir ----------------------------*- MLIR -*-===//
//
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

// RUN: air-opt -air-wrap-func-with-parallel='loop-bounds=4,4,1' %s | FileCheck %s
// CHECK-LABEL: @func0
// CHECK: scf.parallel (%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]]) = (%c0{{.*}}, %c0{{.*}}, %c0{{.*}}) to (%c4{{.*}}, %c4{{.*}}, %c1{{.*}}) step (%c1{{.*}}, %c1{{.*}}, %c1{{.*}})
// CHECK: arith.index_cast %[[ARG0]] : index to i32
// CHECK: arith.index_cast %[[ARG1]] : index to i32

func.func @func0(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%cst = arith.constant 0.000000e+00 : f32
%c64 = arith.constant 64 : index
%c32_i32 = arith.constant 32 : i32
%0 = arith.muli %arg6, %c32_i32 : i32
%1 = arith.index_cast %0 : i32 to index
%2 = arith.muli %arg7, %c32_i32 : i32
%3 = arith.index_cast %2 : i32 to index
%4 = arith.muli %1, %c64 : index
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%4], sizes: [32, 64], strides: [%c64, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%alloc = memref.alloc() : memref<32x64xf32>
memref.copy %reinterpret_cast, %alloc : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<32x64xf32>
%5 = bufferization.to_tensor %alloc restrict writable : memref<32x64xf32> to tensor<32x64xf32>
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%3], sizes: [64, 32], strides: [%c64, 1] : memref<*xf32> to memref<64x32xf32, strided<[?, 1], offset: ?>>
%alloc_1 = memref.alloc() : memref<64x32xf32>
memref.copy %reinterpret_cast_0, %alloc_1 : memref<64x32xf32, strided<[?, 1], offset: ?>> to memref<64x32xf32>
%6 = bufferization.to_tensor %alloc_1 restrict writable : memref<64x32xf32> to tensor<64x32xf32>
%7 = tensor.empty() : tensor<32x32xf32>
%8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<32x32xf32>) -> tensor<32x32xf32>
%9 = linalg.matmul ins(%5, %6 : tensor<32x64xf32>, tensor<64x32xf32>) outs(%8 : tensor<32x32xf32>) -> tensor<32x32xf32>
%10 = arith.addi %4, %3 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg2 to offset: [%10], sizes: [32, 32], strides: [%c64, 1] : memref<*xf32> to memref<32x32xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %9 in writable %reinterpret_cast_2 : (tensor<32x32xf32>, memref<32x32xf32, strided<[?, 1], offset: ?>>) -> ()
return
}