Skip to content

Commit

Permalink
Add support for unranked memref in airrt dialect (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Feb 13, 2025
1 parent efc2a13 commit 04e396c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
6 changes: 3 additions & 3 deletions mlir/include/air/Dialect/AIRRt/AIRRtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def AIRRt_DmaMemcpyNdOp: AIRRt_Op<"dma_memcpy_nd", []> {
ins I32:$id,
I64:$x,
I64:$y,
AnyMemRef:$memref,
AnyRankedOrUnrankedMemRef:$memref,
I64:$offset3,
I64:$offset2,
I64:$offset1,
Expand Down Expand Up @@ -149,8 +149,8 @@ def AIRRt_DmaMemcpyNdOp: AIRRt_Op<"dma_memcpy_nd", []> {
def AIRRt_MemcpyNdOp: AIRRt_Op<"memcpy_nd", []> {
let summary = "dma operator";
let arguments = (
ins AnyMemRef:$dst,
AnyMemRef:$src,
ins AnyRankedOrUnrankedMemRef:$dst,
AnyRankedOrUnrankedMemRef:$src,
I64:$offset3,
I64:$offset2,
I64:$offset1,
Expand Down
18 changes: 10 additions & 8 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ class AIRSegmentConversion : public ConversionPattern {
rewriter.clone(o, remap);
} else if (auto chanOp = dyn_cast<air::ChannelInterface>(o)) {
// clone L3 get/put
MemRefType memrefTy =
llvm::cast<MemRefType>(chanOp.getMemref().getType());
BaseMemRefType memrefTy =
llvm::cast<BaseMemRefType>(chanOp.getMemref().getType());
if (memrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3) {
rewriter.clone(o, remap);
continue;
Expand Down Expand Up @@ -329,8 +329,10 @@ class AIRDmaMemcpyNdToAIRRtConversion
rewriter.create<airrt::WaitAllOp>(
op->getLoc(), airrt::EventType::get(op->getContext()), deps);

MemRefType src = llvm::cast<MemRefType>(op.getSrcMemref().getType());
MemRefType dst = llvm::cast<MemRefType>(op.getDstMemref().getType());
BaseMemRefType src =
llvm::cast<BaseMemRefType>(op.getSrcMemref().getType());
BaseMemRefType dst =
llvm::cast<BaseMemRefType>(op.getDstMemref().getType());
bool isFromTile = false;
bool isFullMemcpy = false;
if (src.getMemorySpaceAsInt() == (int)air::MemorySpace::L1 &&
Expand Down Expand Up @@ -455,8 +457,8 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
auto loc = thisOp->getLoc();
auto ctx = thisOp->getContext();

MemRefType thisMemrefType =
llvm::cast<MemRefType>(thisOp.getMemref().getType());
BaseMemRefType thisMemrefType =
llvm::cast<BaseMemRefType>(thisOp.getMemref().getType());

bool thisOpIsInShim =
thisMemrefType.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L3;
Expand Down Expand Up @@ -640,7 +642,7 @@ class L2DeallocToAIRRtConversion : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dealloc = cast<memref::DeallocOp>(op);
auto type = llvm::cast<MemRefType>(dealloc.getMemref().getType());
auto type = llvm::cast<BaseMemRefType>(dealloc.getMemref().getType());
if (type.getMemorySpaceAsInt() == (int)air::MemorySpace::L2) {
rewriter.replaceOpWithNewOp<airrt::DeallocOp>(op, SmallVector<Type>{},
op->getOperands());
Expand Down Expand Up @@ -965,7 +967,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
});

target.addDynamicallyLegalOp<memref::DeallocOp>([&](memref::DeallocOp op) {
return (llvm::cast<MemRefType>(op.getMemref().getType())
return (llvm::cast<BaseMemRefType>(op.getMemref().getType())
.getMemorySpaceAsInt() != (int)air::MemorySpace::L2);
});

Expand Down

0 comments on commit 04e396c

Please sign in to comment.