Skip to content

Commit

Permalink
Rewrite BufferMemrefToFuncArgs as pattern and fold greedily (#853)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Jan 10, 2025
1 parent c78cf3f commit ad90c73
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 55 deletions.
4 changes: 4 additions & 0 deletions mlir/include/air/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ SmallVector<Operation *> cloneDefiningOpsInRegion(OpBuilder builder,
SmallVectorImpl<Value> &opers,
IRMapping &remap);

// Buffer all allocations of memref directly within the func op's body into the
// func op's arguments.
void populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns);

} // namespace air
} // namespace xilinx

Expand Down
58 changes: 3 additions & 55 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,11 +1025,10 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
// Unroll any affine for loops
unrollAffineFors(module);

// Buffer npu.dma_memcpy_nd memref to function's argument list.
BufferMemrefToFuncArgs(module);

// Cast buffers to i32 types
// Cast buffers to i32 types; buffer npu.dma_memcpy_nd memref to function's
// argument list.
RewritePatternSet castPattern(ctx);
air::populateBufferMemrefToFuncArgsPattern(castPattern);
castPattern.add(CastFunctionArgs);
(void)applyPatternsAndFoldGreedily(module, std::move(castPattern));

Expand Down Expand Up @@ -1631,57 +1630,6 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
chanToIdMap[col]++));
});
}

// Buffers npu.dma_memcpy_op memref as function argument
void BufferMemrefToFuncArgs(ModuleOp module) {
module.walk([&](mlir::func::FuncOp f) { BufferMemrefToFuncArgs(f); });
}
void BufferMemrefToFuncArgs(func::FuncOp funcOp) {
if (!funcOp)
return;

// Collect illegal dma ops whose memrefs are not in function's arguments.
SmallVector<Type, 6> memrefTypes;
SmallVector<Value, 6> memrefs;
funcOp.walk([&](AIEX::NpuDmaMemcpyNdOp dma) {
Value memref = dma.getMemref();
auto args = funcOp.getArguments();
// if the memref is an arg, return
if (std::find(args.begin(), args.end(), memref) != args.end())
return;
// if the memref is the result of a cast of an arg, return
if (auto cast = dyn_cast_or_null<UnrealizedConversionCastOp>(
memref.getDefiningOp())) {
if (std::find(args.begin(), args.end(), cast.getOperand(0)) !=
args.end())
return;
else
memref = cast.getOperand(0);
}
// push back if unique
if (std::find(memrefs.begin(), memrefs.end(), memref) == memrefs.end()) {
memrefs.push_back(memref);
memrefTypes.push_back(memref.getType());
}
});

// Append memref to function's arguments.
auto functionType = funcOp.getFunctionType();
auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), memrefTypes));
auto newFunctionType = FunctionType::get(funcOp.getContext(), newArgTypes,
functionType.getResults());
funcOp.setType(newFunctionType);

// Add the new arguments to the entry block if the function is not external.
if (!funcOp.isExternal()) {
Location loc = funcOp.getLoc();
for (Value v : memrefs) {
auto newArg = funcOp.front().addArgument(v.getType(), loc);
v.replaceAllUsesWith(newArg);
}
}
}
};

} // namespace
Expand Down
54 changes: 54 additions & 0 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,3 +1693,57 @@ air::cloneDefiningOpsInRegion(OpBuilder builder, Region *region,
clonedOps.push_back(builder.clone(*op, remap));
return clonedOps;
}

// Buffer all allocations of L3 memref directly within the func op's body into
// the func op's arguments.
struct BufferMemrefToFuncArgsPattern : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern<func::FuncOp>::OpRewritePattern;

LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {

if (funcOp.isExternal())
return failure();

SmallVector<Type, 6> memrefTypes;
llvm::SetVector<Value> memrefs;
for (auto &op : funcOp.getFunctionBody().getOps()) {
if (isa<CastOpInterface>(op))
continue;
for (auto res : op.getResults()) {
MemRefType resType = dyn_cast<MemRefType>(res.getType());
if (!resType)
continue;
if (resType.getMemorySpaceAsInt() == (int)air::MemorySpace::L3)
memrefs.insert(res);
}
}
for (auto memref : memrefs)
memrefTypes.push_back(memref.getType());
if (memrefs.empty())
return failure();

// Append memref to function's arguments.
auto functionType = funcOp.getFunctionType();
auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), memrefTypes));
auto newFunctionType = FunctionType::get(funcOp.getContext(), newArgTypes,
functionType.getResults());
funcOp.setType(newFunctionType);

// Add the new arguments to the entry block if the function is not external.
Location loc = funcOp.getLoc();
for (Value v : memrefs) {
auto newArg = funcOp.front().addArgument(v.getType(), loc);
v.replaceAllUsesWith(newArg);
}
return success();
}

private:
};

void air::populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns.insert<BufferMemrefToFuncArgsPattern>(ctx);
}

0 comments on commit ad90c73

Please sign in to comment.