From 04e396c861c63cdd5e07d27bdbb166205fdeebb5 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Thu, 13 Feb 2025 14:05:24 -0800 Subject: [PATCH] Add support for unranked memref in airrt dialect (#895) --- mlir/include/air/Dialect/AIRRt/AIRRtOps.td | 6 +++--- mlir/lib/Conversion/AIRLoweringPass.cpp | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mlir/include/air/Dialect/AIRRt/AIRRtOps.td b/mlir/include/air/Dialect/AIRRt/AIRRtOps.td index d107c850a..0b1fdfcad 100644 --- a/mlir/include/air/Dialect/AIRRt/AIRRtOps.td +++ b/mlir/include/air/Dialect/AIRRt/AIRRtOps.td @@ -117,7 +117,7 @@ def AIRRt_DmaMemcpyNdOp: AIRRt_Op<"dma_memcpy_nd", []> { ins I32:$id, I64:$x, I64:$y, - AnyMemRef:$memref, + AnyRankedOrUnrankedMemRef:$memref, I64:$offset3, I64:$offset2, I64:$offset1, @@ -149,8 +149,8 @@ def AIRRt_DmaMemcpyNdOp: AIRRt_Op<"dma_memcpy_nd", []> { def AIRRt_MemcpyNdOp: AIRRt_Op<"memcpy_nd", []> { let summary = "dma operator"; let arguments = ( - ins AnyMemRef:$dst, - AnyMemRef:$src, + ins AnyRankedOrUnrankedMemRef:$dst, + AnyRankedOrUnrankedMemRef:$src, I64:$offset3, I64:$offset2, I64:$offset1, diff --git a/mlir/lib/Conversion/AIRLoweringPass.cpp b/mlir/lib/Conversion/AIRLoweringPass.cpp index 1a4fb7e51..6dc27eeb3 100644 --- a/mlir/lib/Conversion/AIRLoweringPass.cpp +++ b/mlir/lib/Conversion/AIRLoweringPass.cpp @@ -180,8 +180,8 @@ class AIRSegmentConversion : public ConversionPattern { rewriter.clone(o, remap); } else if (auto chanOp = dyn_cast(o)) { // clone L3 get/put - MemRefType memrefTy = - llvm::cast(chanOp.getMemref().getType()); + BaseMemRefType memrefTy = + llvm::cast(chanOp.getMemref().getType()); if (memrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3) { rewriter.clone(o, remap); continue; @@ -329,8 +329,10 @@ class AIRDmaMemcpyNdToAIRRtConversion rewriter.create( op->getLoc(), airrt::EventType::get(op->getContext()), deps); - MemRefType src = llvm::cast(op.getSrcMemref().getType()); - MemRefType dst = llvm::cast(op.getDstMemref().getType()); + BaseMemRefType src = + llvm::cast(op.getSrcMemref().getType()); + BaseMemRefType dst = + llvm::cast(op.getDstMemref().getType()); bool isFromTile = false; bool isFullMemcpy = false; if (src.getMemorySpaceAsInt() == (int)air::MemorySpace::L1 && @@ -455,8 +457,8 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder, auto loc = thisOp->getLoc(); auto ctx = thisOp->getContext(); - MemRefType thisMemrefType = - llvm::cast(thisOp.getMemref().getType()); + BaseMemRefType thisMemrefType = + llvm::cast(thisOp.getMemref().getType()); bool thisOpIsInShim = thisMemrefType.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L3; @@ -640,7 +642,7 @@ class L2DeallocToAIRRtConversion : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dealloc = cast(op); - auto type = llvm::cast(dealloc.getMemref().getType()); + auto type = llvm::cast(dealloc.getMemref().getType()); if (type.getMemorySpaceAsInt() == (int)air::MemorySpace::L2) { rewriter.replaceOpWithNewOp(op, SmallVector{}, op->getOperands()); @@ -965,7 +967,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase { }); target.addDynamicallyLegalOp([&](memref::DeallocOp op) { - return (llvm::cast(op.getMemref().getType()) + return (llvm::cast(op.getMemref().getType()) .getMemorySpaceAsInt() != (int)air::MemorySpace::L2); });