Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for unranked memref #894

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/air/Dialect/AIR/AIR.td
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def air_ChannelPutOp : air_Op<"channel.put", [air_AsyncOpInterface,
Arguments<(ins Variadic<air_AsyncToken>:$async_dependencies,
FlatSymbolRefAttr:$chan_name,
Variadic<Index>:$indices,
AnyMemRef:$src,
AnyRankedOrUnrankedMemRef:$src,
Variadic<Index>:$src_offsets,
Variadic<Index>:$src_sizes,
Variadic<Index>:$src_strides)>,
Expand Down Expand Up @@ -480,7 +480,7 @@ def air_ChannelGetOp : air_Op<"channel.get", [air_AsyncOpInterface,
Arguments<(ins Variadic<air_AsyncToken>:$async_dependencies,
FlatSymbolRefAttr:$chan_name,
Variadic<Index>:$indices,
AnyMemRef:$dst,
AnyRankedOrUnrankedMemRef:$dst,
Variadic<Index>:$dst_offsets,
Variadic<Index>:$dst_sizes,
Variadic<Index>:$dst_strides)>,
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Conversion/AIRToAIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct AIRToAIEConversionOptions {

// get memcpy operation volumn (elements) as int
int getMemcpySizesAsInt(Value memref, SmallVector<Value> sizes) {
MemRefType memTy = llvm::cast<MemRefType>(memref.getType());
BaseMemRefType memTy = llvm::cast<BaseMemRefType>(memref.getType());
if (sizes.empty())
return getTensorVolume(memTy);
else {
Expand Down Expand Up @@ -1029,8 +1029,8 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
return res;

// check if this put is linked to a get from another channel
MemRefType memref =
llvm::cast<MemRefType>(channelPuts[0].getMemref().getType());
BaseMemRefType memref =
llvm::cast<BaseMemRefType>(channelPuts[0].getMemref().getType());
int mem_space = memref.getMemorySpaceAsInt();
if (mem_space == (int)air::MemorySpace::L2) {
if (linksToComplete.find(channelPuts[0].getOperation()) !=
Expand Down Expand Up @@ -1070,7 +1070,8 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
consumers.push_back(consumerTile);

// check if this get is linked to a put from another channel
MemRefType memref = llvm::cast<MemRefType>(get.getMemref().getType());
BaseMemRefType memref =
llvm::cast<BaseMemRefType>(get.getMemref().getType());
int mem_space = memref.getMemorySpaceAsInt();
if (mem_space == (int)air::MemorySpace::L2) {
if (linksToComplete.find(get.getOperation()) != linksToComplete.end()) {
Expand Down Expand Up @@ -1226,7 +1227,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
AIE::ObjectFifoCreateOp objFifo,
AIE::ObjectFifoPort port,
llvm::SmallSet<Operation *, 2> &erased_allocs) const {
MemRefType memref = cast<MemRefType>(op.getMemref().getType());
BaseMemRefType memref = cast<BaseMemRefType>(op.getMemref().getType());
int mem_space = memref.getMemorySpaceAsInt();
if (mem_space == (int)air::MemorySpace::L2) {
// add alloc to list of ops to erase
Expand Down Expand Up @@ -1260,7 +1261,8 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
PatternRewriter &rewriter, MyOp op, AIE::ObjectFifoCreateOp objFifo,
AIE::ObjectFifoPort port,
llvm::SmallSet<Operation *, 2> &erased_deallocs) const {
MemRefType memref = llvm::cast<MemRefType>(op.getMemref().getType());
BaseMemRefType memref =
llvm::cast<BaseMemRefType>(op.getMemref().getType());
int mem_space = memref.getMemorySpaceAsInt();
if (mem_space == (int)air::MemorySpace::L2) {
return;
Expand Down
22 changes: 12 additions & 10 deletions mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ using namespace xilinx;
bool air::isTileInbound(air::MemcpyInterface memcpyOp, int tileMemSpaceAsInt) {
if (memcpyOp.getSrcMemref() && memcpyOp.getDstMemref()) {
int src_memory_space =
llvm::cast<MemRefType>(memcpyOp.getSrcMemref().getType())
llvm::cast<BaseMemRefType>(memcpyOp.getSrcMemref().getType())
.getMemorySpaceAsInt();
int dst_memory_space =
llvm::cast<MemRefType>(memcpyOp.getDstMemref().getType())
llvm::cast<BaseMemRefType>(memcpyOp.getDstMemref().getType())
.getMemorySpaceAsInt();
assert(src_memory_space !=
dst_memory_space); // air.dmaMemcpyNdOp isn't meant to represent
Expand Down Expand Up @@ -319,7 +319,7 @@ air::getLockValuePair(const AIE::AIETargetModel &targetModel,

// Infer semaphore lock values using buffer op
// TODO: What if a buffer memref is read or written by multiple channels?
if (!llvm::isa<MemRefType>(buffer_memref.getType()))
if (!llvm::isa<BaseMemRefType>(buffer_memref.getType()))
return std::make_pair(-1, -1);
int read_counter = 0;
int write_counter = 0;
Expand Down Expand Up @@ -360,7 +360,7 @@ air::getLockValuePair(const AIE::AIETargetModel &targetModel,
if (!targetModel.hasProperty(AIE::AIETargetModel::UsesSemaphoreLocks))
return std::make_pair(0, 0);

if (!llvm::isa<MemRefType>(buffer_memref.getType()))
if (!llvm::isa<BaseMemRefType>(buffer_memref.getType()))
return std::make_pair(-1, -1);

if (!air_chan)
Expand Down Expand Up @@ -984,11 +984,11 @@ void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::DmaMemcpyNdOp memcpyOp) {
// air::DmaMemcpyNdOp is a complete memcpy with both src and dst
S2MM[0].push_back(memcpyOp.getOperation());
S2MM_memspace_as_int =
llvm::cast<MemRefType>(memcpyOp.getDstMemref().getType())
llvm::cast<BaseMemRefType>(memcpyOp.getDstMemref().getType())
.getMemorySpaceAsInt();
MM2S.push_back(memcpyOp.getOperation());
MM2S_memspace_as_int =
llvm::cast<MemRefType>(memcpyOp.getSrcMemref().getType())
llvm::cast<BaseMemRefType>(memcpyOp.getSrcMemref().getType())
.getMemorySpaceAsInt();
}

Expand Down Expand Up @@ -1019,16 +1019,18 @@ void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp) {
}
air_flow_op = chan.getOperation();
S2MM[alloc_id].push_back(memcpyOp.getOperation());
S2MM_memspace_as_int = llvm::cast<MemRefType>(memcpyOp.getMemref().getType())
.getMemorySpaceAsInt();
S2MM_memspace_as_int =
llvm::cast<BaseMemRefType>(memcpyOp.getMemref().getType())
.getMemorySpaceAsInt();
}

void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelPutOp memcpyOp) {
auto chan = air::getChannelDeclarationThroughSymbol(memcpyOp);
air_flow_op = chan.getOperation();
MM2S.push_back(memcpyOp.getOperation());
MM2S_memspace_as_int = llvm::cast<MemRefType>(memcpyOp.getMemref().getType())
.getMemorySpaceAsInt();
MM2S_memspace_as_int =
llvm::cast<BaseMemRefType>(memcpyOp.getMemref().getType())
.getMemorySpaceAsInt();
}

void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Transform/AIRDependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
else {
bool isCandidateExecute = false;
for (auto operand : op->getOperands()) {
if (llvm::isa<MemRefType>(operand.getType()) ||
if (llvm::isa<BaseMemRefType>(operand.getType()) ||
llvm::isa<IndexType>(operand.getType())) {
isCandidateExecute = true;
}
Expand Down Expand Up @@ -798,7 +798,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
}

char checkOperandReadOrWrite(mlir::Value operand) {
if (!llvm::isa<MemRefType>(operand.getType())) {
if (!llvm::isa<BaseMemRefType>(operand.getType())) {
operand.getDefiningOp()->emitOpError(
"operand being traced is not a memref");
}
Expand Down Expand Up @@ -843,7 +843,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
template <typename T>
void pushDepsAtCurrentScope(mlir::Value operand, T op, char rw = 'n',
partialMemref *tile = nullptr) {
if (!llvm::isa<MemRefType>(operand.getType())) {
if (!llvm::isa<BaseMemRefType>(operand.getType())) {
operand.getDefiningOp()->emitOpError(
"operand being traced is not a memref");
}
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2389,7 +2389,7 @@ struct UnrollChannelByFactorPattern {
// Update memref size (divide by factor)
SmallVector<Value, 1> new_sizes = op.getSizes();
if (new_sizes.empty()) {
auto memTy = llvm::cast<MemRefType>(op.getMemref().getType());
auto memTy = llvm::cast<BaseMemRefType>(op.getMemref().getType());
for (auto d : getTensorShape(memTy)) {
new_sizes.push_back(
builder.create<arith::ConstantIndexOp>(par->getLoc(), d));
Expand Down Expand Up @@ -3721,15 +3721,15 @@ class AIRFuseChannels
bool hitsMemorySpaceForAggMode(std::vector<air::ChannelPutOp> &puts,
std::vector<air::ChannelGetOp> &gets) {
for (auto put : puts) {
MemRefType ty = llvm::cast<MemRefType>(put.getMemref().getType());
BaseMemRefType ty = llvm::cast<BaseMemRefType>(put.getMemref().getType());
if (llvm::any_of(targetMemorySpaces, [&](unsigned memSpace) {
return memSpace == ty.getMemorySpaceAsInt();
})) {
return true;
}
}
for (auto get : gets) {
MemRefType ty = llvm::cast<MemRefType>(get.getMemref().getType());
BaseMemRefType ty = llvm::cast<BaseMemRefType>(get.getMemref().getType());
if (llvm::any_of(targetMemorySpaces, [&](unsigned memSpace) {
return memSpace == ty.getMemorySpaceAsInt();
})) {
Expand Down Expand Up @@ -4596,9 +4596,10 @@ struct ShrinkMemrefSizesByAccessPattern
}

// Replace memref alloc op;
Type elemType = llvm::cast<MemRefType>(memref.getType()).getElementType();
Type elemType =
llvm::cast<BaseMemRefType>(memref.getType()).getElementType();
Attribute memorySpace =
llvm::cast<MemRefType>(memref.getType()).getMemorySpace();
llvm::cast<BaseMemRefType>(memref.getType()).getMemorySpace();
auto newMemrefType = MemRefType::get(overall_access_bounds, elemType,
nullptr, memorySpace);
if (auto execOp = dyn_cast<air::ExecuteOp>(alloc->getParentOp())) {
Expand Down Expand Up @@ -4740,10 +4741,10 @@ struct ShrinkMemrefSizesByAccessPattern
auto static_strides = subViewOp.getStaticStrides();
auto static_offsets = subViewOp.getStaticOffsets();
// Get MemRefType after shrinkage.
Type elemType = llvm::cast<MemRefType>(subViewOp.getSource().getType())
Type elemType = llvm::cast<BaseMemRefType>(subViewOp.getSource().getType())
.getElementType();
Attribute memorySpace =
llvm::cast<MemRefType>(subViewOp.getSource().getType())
llvm::cast<BaseMemRefType>(subViewOp.getSource().getType())
.getMemorySpace();
auto shrunkMemrefType =
MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace);
Expand Down
31 changes: 16 additions & 15 deletions mlir/lib/Transform/AIRDmaToChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ static void replaceAIRDmaWithAIRChannelPairs(
auto dst = op.getDstMemref();
auto ctx = op->getContext();

auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
auto src_type = llvm::dyn_cast<BaseMemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<BaseMemRefType>(dst.getType());
SmallVector<Value, 4> src_offsets = op.getSrcOffsets();
SmallVector<Value, 4> dst_offsets = op.getDstOffsets();
SmallVector<Value, 4> src_sizes = op.getSrcSizes();
Expand Down Expand Up @@ -605,16 +605,16 @@ LogicalResult HoistingAffineIf(affine::AffineIfOp op) {
// Label dependent ops to hoist
for (auto b : backwardSlice) {
b->setAttr("hoist", StringAttr::get(ctx, "dep"));
if (dyn_cast<air::ExecuteOp>(b)) {
auto child_op = &(*b->getRegions().front().op_begin());
child_op->setAttr("hoist", StringAttr::get(ctx, "dep"));
if (auto exec = dyn_cast<air::ExecuteOp>(b)) {
auto &child_op = exec.getChildOps().front();
child_op.setAttr("hoist", StringAttr::get(ctx, "dep"));
}
}

// Hoist hierarchy op into scf op
module_builder.setInsertionPoint(hier_op);
MemRefType externalMemrefTy =
llvm::cast<MemRefType>(externalGetPut[0].getMemref().getType());
auto externalMemrefTy =
llvm::dyn_cast<BaseMemRefType>(externalGetPut[0].getMemref().getType());
if (externalMemrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3 &&
segment) {
module_builder.setInsertionPoint(segment);
Expand Down Expand Up @@ -769,8 +769,8 @@ class AIRDmaToAIRChannelConversion
auto ctx = op->getContext();

// It must already be a memref
auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
auto src_type = llvm::dyn_cast<BaseMemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<BaseMemRefType>(dst.getType());
if (!src_type)
return failure();

Expand Down Expand Up @@ -1068,7 +1068,7 @@ static LogicalResult AIRDemoteMemrefToAIRHierarchy(
auto memref =
isa<air::ExecuteOp>(op) ? op->getResult(1) : op->getResult(0);
auto token = isa<air::ExecuteOp>(op) ? op->getResult(0) : nullptr;
auto memref_type = llvm::dyn_cast<MemRefType>(memref.getType());
auto memref_type = llvm::dyn_cast<BaseMemRefType>(memref.getType());

if (memref_type.getMemorySpaceAsInt() == hierMemorySpace)
continue; // Alloc op is already under correct hierarchy
Expand Down Expand Up @@ -1153,8 +1153,8 @@ class AIRDemoteDmaToAIRHierarchyConversion
auto ctx = op->getContext();

// It must already be a memref
auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
auto src_type = llvm::dyn_cast<BaseMemRefType>(src.getType());
auto dst_type = llvm::dyn_cast<BaseMemRefType>(dst.getType());
if (!src_type)
return failure();

Expand Down Expand Up @@ -1391,7 +1391,8 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
std::map<air::HierarchyInterface, std::vector<Operation *>> hier_to_allocs;
for (auto f : funcOps) {
f.walk([&](memref::AllocOp alloc) {
auto memref_type = dyn_cast<MemRefType>(alloc.getMemref().getType());
auto memref_type =
dyn_cast<BaseMemRefType>(alloc.getMemref().getType());
int hierMemorySpace = (int)air::MemorySpace::L3;
air::HierarchyInterface hier_op =
alloc->getParentOfType<air::HierarchyInterface>();
Expand Down Expand Up @@ -1439,9 +1440,9 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
target_0.addDynamicallyLegalOp<air::DmaMemcpyNdOp>(
[&](air::DmaMemcpyNdOp dma) {
auto src_type =
llvm::dyn_cast<MemRefType>(dma.getSrcMemref().getType());
llvm::dyn_cast<BaseMemRefType>(dma.getSrcMemref().getType());
auto dst_type =
llvm::dyn_cast<MemRefType>(dma.getDstMemref().getType());
llvm::dyn_cast<BaseMemRefType>(dma.getDstMemref().getType());
if (dma->getParentOfType<air::HerdOp>()) {
if (src_type.getMemorySpaceAsInt() < (int)air::MemorySpace::L1 &&
dst_type.getMemorySpaceAsInt() < (int)air::MemorySpace::L1)
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ getAllReadAccessedMemrefOperandsFromOp(Operation *op) {
return entry;
};
auto pushMemrefEntryToVector = [](auto entry, auto &vector) {
if (!isa<MemRefType>(entry.first.getType()))
if (!isa<BaseMemRefType>(entry.first.getType()))
return;
vector.push_back(entry);
};
Expand Down Expand Up @@ -1127,7 +1127,7 @@ getAllWriteAccessedMemrefOperandsFromOp(Operation *op) {
return entry;
};
auto pushMemrefEntryToVector = [](auto entry, auto &vector) {
if (!isa<MemRefType>(entry.first.getType()))
if (!isa<BaseMemRefType>(entry.first.getType()))
return;
vector.push_back(entry);
};
Expand Down Expand Up @@ -1597,7 +1597,8 @@ Graph::VertexId dependencyCanonicalizer::addVertexFromExecuteOp(
// Annotate memref's memory space
std::string memorySpaceStr =
getMemorySpaceAsString(alloc_child_op.getMemref());
auto ty = llvm::cast<MemRefType>(alloc_child_op.getMemref().getType());
auto ty =
llvm::cast<BaseMemRefType>(alloc_child_op.getMemref().getType());
detailed_description += "(" + memorySpaceStr + ", " +
std::to_string(getTensorVolume(ty)) + ", " +
getElementTypeAsString(ty) + ")";
Expand All @@ -1609,7 +1610,8 @@ Graph::VertexId dependencyCanonicalizer::addVertexFromExecuteOp(
// Annotate memref's memory space
std::string memorySpaceStr =
getMemorySpaceAsString(dealloc_child_op.getMemref());
auto ty = llvm::cast<MemRefType>(dealloc_child_op.getMemref().getType());
auto ty =
llvm::cast<BaseMemRefType>(dealloc_child_op.getMemref().getType());
detailed_description += "(" + memorySpaceStr + ", " +
std::to_string(getTensorVolume(ty)) + ", " +
getElementTypeAsString(ty) + ")";
Expand Down Expand Up @@ -2321,7 +2323,7 @@ void dependencyCanonicalizer::redoDepTraceIfDepOnHier(func::FuncOp func) {
void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand,
air::AsyncOpInterface op, char rw,
partialMemref *tile) {
if (!llvm::isa<MemRefType>(operand.getType()))
if (!llvm::isa<BaseMemRefType>(operand.getType()))
op->emitOpError("operand being traced is not a memref");
for (auto &u : operand.getUses()) {
// If used in MemcpyInterface Op
Expand Down Expand Up @@ -2531,7 +2533,7 @@ bool dependencyTracer::areEqualIndexPartialMemrefs(partialMemref *tile_0,
}

char dependencyTracer::checkOperandReadOrWrite(mlir::Value operand) {
if (!llvm::isa<MemRefType>(operand.getType()))
if (!llvm::isa<BaseMemRefType>(operand.getType()))
operand.getDefiningOp()->emitOpError(
"operand being traced is not a memref");
bool foundWriteAccess = false;
Expand Down
Loading
Loading