From 82796c38b93906b6d3c2539d5e36f75f5fb9c385 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 17:04:54 -0800 Subject: [PATCH 1/6] Add findCommonRegionContainingAllAncestors method to Util --- mlir/include/air/Util/Util.h | 5 +++++ mlir/lib/Util/Util.cpp | 30 ++++++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mlir/include/air/Util/Util.h b/mlir/include/air/Util/Util.h index a918e3a1f..270917945 100644 --- a/mlir/include/air/Util/Util.h +++ b/mlir/include/air/Util/Util.h @@ -255,6 +255,11 @@ void getBackwardSliceInRegion(OpBuilder builder, Region *region, // func op's arguments. void populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns); +// Find a common region that contains all ops, or ancestors of ops, until a +// specified region. +Region *findCommonRegionContainingAllAncestors(SmallVector ops, + Operation *until); + } // namespace air } // namespace xilinx diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 4644b7121..4a4284ef8 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -1375,8 +1375,10 @@ air::writeAccessPattern(mlir::vector::TransferReadOp readOp) { OpBuilder builder(readOp); std::tuple, SmallVector, SmallVector> pattern; - [[maybe_unused]] auto vectorTy = llvm::cast(readOp.getVector().getType()); - [[maybe_unused]] auto memrefTy = llvm::cast(readOp.getSource().getType()); + [[maybe_unused]] auto vectorTy = + llvm::cast(readOp.getVector().getType()); + [[maybe_unused]] auto memrefTy = + llvm::cast(readOp.getSource().getType()); assert(vectorTy && "Not a vector"); assert(memrefTy && "Not a memref"); // Initialize wraps and strides based on the unshrunk memref shape. @@ -1400,8 +1402,10 @@ air::writeAccessPattern(mlir::vector::TransferWriteOp writeOp) { OpBuilder builder(writeOp); std::tuple, SmallVector, SmallVector> pattern; - [[maybe_unused]] auto memrefTy = llvm::cast(writeOp.getSource().getType()); - [[maybe_unused]] auto vectorTy = llvm::cast(writeOp.getVector().getType()); + [[maybe_unused]] auto memrefTy = + llvm::cast(writeOp.getSource().getType()); + [[maybe_unused]] auto vectorTy = + llvm::cast(writeOp.getVector().getType()); assert(memrefTy && "Not a memref"); assert(vectorTy && "Not a vector"); // Initialize wraps and strides based on the unshrunk memref shape. @@ -1735,3 +1739,21 @@ void air::populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); patterns.insert(ctx); } + +// Find a common region that contains all ops, or ancestors of ops, until it +// reaches a specified parent op. +Region * +air::findCommonRegionContainingAllAncestors(SmallVector ops, + Operation *until) { + Region *region = ops.front()->getParentRegion(); + while (llvm::any_of(ops, [region](Operation *op) { + return !region->isAncestor(op->getParentRegion()); + })) { + // Failed to find any shared region within the parent IsolatedFromAbove op + // body. + if (until && region->getParentOp() == until) + return nullptr; + region = region->getParentRegion(); + } + return region; +} From 288119e6b67f6c8d0d11b21ee5fcda7c69214256 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 17:05:31 -0800 Subject: [PATCH 2/6] Apply findCommonRegionContainingAllAncestors method --- .../Transform/AIRDependencyScheduleOpt.cpp | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index e6ef52cf6..b6266e5bd 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -836,12 +836,12 @@ struct HoistAIRHerdsToSharedRegionPattern PatternRewriter &rewriter) const override { auto parentRegion = herdOp->getParentRegion(); auto symName = herdOp.getSymName(); - SmallVector herdsWithSameName; + SmallVector herdsWithSameName; parentRegion->walk>( [symName, &herdsWithSameName](air::HerdOp herd) { if (herd.getSymName() != symName) return WalkResult::skip(); - herdsWithSameName.push_back(herd); + herdsWithSameName.push_back(herd.getOperation()); return WalkResult::advance(); }); if (herdOp != herdsWithSameName.back()) @@ -849,19 +849,15 @@ struct HoistAIRHerdsToSharedRegionPattern // Get the innermost region that is ancestor to all herds sharing the same // name - Region *region = herdOp->getParentRegion(); - while (llvm::any_of(herdsWithSameName, [region](air::HerdOp h) { - return !region->isAncestor(h->getParentRegion()); - })) { - // Failed to find any shared region within the parent IsolatedFromAbove op - // body. - if (region->getParentOp()->mightHaveTrait()) - return failure(); - region = region->getParentRegion(); - } + Region *region = air::findCommonRegionContainingAllAncestors( + herdsWithSameName, + herdsWithSameName.front() + ->getParentWithTrait()); + if (!region) + return failure(); // If none of the herds are directly contained in region, then abort herd // hoisting; otherwise the loop hoisting breaks the IR functionality. - if (llvm::none_of(herdsWithSameName, [region](air::HerdOp h) { + if (llvm::none_of(herdsWithSameName, [region](Operation *h) { return h->getParentRegion() == region; })) return failure(); @@ -876,7 +872,8 @@ struct HoistAIRHerdsToSharedRegionPattern // Hoist herds to the shared parent region SmallVector processed, unprocessed; for (auto h : herdsWithSameName) { - auto newHerd = hoistAIRHerdInForImpl(h, region, rewriter); + auto newHerd = + hoistAIRHerdInForImpl(dyn_cast(h), region, rewriter); if (succeeded(newHerd)) { rewriter.eraseOp(h); processed.push_back(*newHerd); From b375a83321b3c7ee2b11400c57bcffdf400fd3d6 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 17:13:03 -0800 Subject: [PATCH 3/6] Switch to using DominanceInfo for constucting dependency edges --- mlir/lib/Transform/AIRDependency.cpp | 221 +++++++++++++-------------- 1 file changed, 106 insertions(+), 115 deletions(-) diff --git a/mlir/lib/Transform/AIRDependency.cpp b/mlir/lib/Transform/AIRDependency.cpp index 33b40561d..2437f97b7 100644 --- a/mlir/lib/Transform/AIRDependency.cpp +++ b/mlir/lib/Transform/AIRDependency.cpp @@ -764,7 +764,9 @@ class AIRDependency : public air::impl::AIRDependencyBase { void pushDefiningOpAsDep(Value operand, T op) { // Check memref deps if (auto defop = operand.getDefiningOp()) { - if (foundAsyncOpUsesAboveCurrentLine(&defop)) { + DominanceInfo domInfo(defop); + if (domInfo.properlyDominates(defop, op)) { + // if (foundAsyncOpUsesAboveCurrentLine(&defop)) { addAsyncDepToGraphIfNew(defop.getResult(0), op); } } @@ -777,7 +779,9 @@ class AIRDependency : public air::impl::AIRDependencyBase { // If tile_index is not a nullptr // If created by async_region if (auto defop = tile_index.getDefiningOp()) { - if (foundAsyncOpUsesAboveCurrentLine(&defop)) { + DominanceInfo domInfo(defop); + if (domInfo.properlyDominates(defop, op)) { + // if (foundAsyncOpUsesAboveCurrentLine(&defop)) { addAsyncDepToGraphIfNew(defop.getResult(0), op); } } @@ -843,67 +847,62 @@ class AIRDependency : public air::impl::AIRDependencyBase { operand.getDefiningOp()->emitOpError( "operand being traced is not a memref"); } + auto opOrAncestorIsDominantOver = [](Operation *a, Operation *b) { + Region *commonRegion = air::findCommonRegionContainingAllAncestors( + SmallVector{a, b}, nullptr); + auto aAncestor = commonRegion->findAncestorOpInRegion(*a); + auto bAncestor = commonRegion->findAncestorOpInRegion(*b); + if (!aAncestor || !bAncestor) + return false; + DominanceInfo domInfo(aAncestor); + return domInfo.properlyDominates(aAncestor, bAncestor); + }; for (auto &u : operand.getUses()) { + if (!opOrAncestorIsDominantOver(u.getOwner(), op)) + continue; // If used in MemcpyInterface Op if (auto memcpy = dyn_cast(u.getOwner())) { - bool isUsedAboveThisLine = false; - if (auto dma = dyn_cast(u.getOwner())) { - if (foundAsyncOpUsesAboveCurrentLine( - &dma)) { // If this use is above current line - isUsedAboveThisLine = true; - } - } else if (auto channel = - dyn_cast(u.getOwner())) { - if (foundAsyncOpUsesAboveCurrentLine( - &channel)) { // If this use is above current line - isUsedAboveThisLine = true; - } + partialMemref memcpy_src, memcpy_dst; + if (memcpy.getSrcMemref()) { + memcpy_src = + partialMemref(memcpy.getSrcMemref(), memcpy.getSrcOffsets(), + memcpy.getSrcSizes(), memcpy.getSrcStrides()); + } + if (memcpy.getDstMemref()) { + memcpy_dst = + partialMemref(memcpy.getDstMemref(), memcpy.getDstOffsets(), + memcpy.getDstSizes(), memcpy.getDstStrides()); } - if (isUsedAboveThisLine) { - partialMemref memcpy_src, memcpy_dst; - if (memcpy.getSrcMemref()) { - memcpy_src = - partialMemref(memcpy.getSrcMemref(), memcpy.getSrcOffsets(), - memcpy.getSrcSizes(), memcpy.getSrcStrides()); - } - if (memcpy.getDstMemref()) { - memcpy_dst = - partialMemref(memcpy.getDstMemref(), memcpy.getDstOffsets(), - memcpy.getDstSizes(), memcpy.getDstStrides()); + if (rw == 'r') { + if (u.is(memcpy.getSrcMemref())) { + if (tile == nullptr) { + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), + op); + } else if (areEqualIndexPartialMemrefs(tile, &memcpy_src)) + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), + op); } - - if (rw == 'r') { - if (u.is(memcpy.getSrcMemref())) { - if (tile == nullptr) { - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } else if (areEqualIndexPartialMemrefs(tile, &memcpy_src)) - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } - } else if (rw == 'w') { - if (u.is(memcpy.getDstMemref())) { - if (tile == nullptr) { - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } else if (areEqualIndexPartialMemrefs(tile, &memcpy_dst)) - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } - } else { + } else if (rw == 'w') { + if (u.is(memcpy.getDstMemref())) { if (tile == nullptr) { addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), op); - } else if (u.is(memcpy.getDstMemref())) { - if (areEqualIndexPartialMemrefs(tile, &memcpy_dst)) - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } else if (u.is(memcpy.getSrcMemref())) { - if (areEqualIndexPartialMemrefs(tile, &memcpy_src)) - addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), - op); - } + } else if (areEqualIndexPartialMemrefs(tile, &memcpy_dst)) + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), + op); + } + } else { + if (tile == nullptr) { + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), op); + } else if (u.is(memcpy.getDstMemref())) { + if (areEqualIndexPartialMemrefs(tile, &memcpy_dst)) + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), + op); + } else if (u.is(memcpy.getSrcMemref())) { + if (areEqualIndexPartialMemrefs(tile, &memcpy_src)) + addAsyncDepToGraphIfNew(memcpy.getOperation()->getResult(0), + op); } } } @@ -911,37 +910,33 @@ class AIRDependency : public air::impl::AIRDependencyBase { // If used in a linalg op else if (auto linalgop = mlir::dyn_cast(u.getOwner())) { if (auto ar = dyn_cast(linalgop->getParentOp())) { - if (foundAsyncOpUsesAboveCurrentLine(&ar)) { - if (rw == 'r') { - if (u.getOperandNumber() < - linalgop.getNumDpsInputs() + linalgop.getNumDpsInits()) - addAsyncDepToGraphIfNew(ar.getResult(0), op); - } else if (rw == 'w') { - if (u.getOperandNumber() >= linalgop.getNumDpsInputs() && - u.getOperandNumber() - linalgop.getNumDpsInputs() < - linalgop.getNumDpsInits()) - addAsyncDepToGraphIfNew(ar.getResult(0), op); - } else { + if (rw == 'r') { + if (u.getOperandNumber() < + linalgop.getNumDpsInputs() + linalgop.getNumDpsInits()) addAsyncDepToGraphIfNew(ar.getResult(0), op); - } + } else if (rw == 'w') { + if (u.getOperandNumber() >= linalgop.getNumDpsInputs() && + u.getOperandNumber() - linalgop.getNumDpsInputs() < + linalgop.getNumDpsInits()) + addAsyncDepToGraphIfNew(ar.getResult(0), op); + } else { + addAsyncDepToGraphIfNew(ar.getResult(0), op); } } } // If used in hierarchy op else if (auto hier = dyn_cast(u.getOwner())) { - if (foundAsyncOpUsesAboveCurrentLine(&hier)) { - // check if the use inside hierarchy op matches with the tracing mode - // (r or w) - for (unsigned hier_argument_id = 0; - hier_argument_id < hier.getNumKernelOperands(); - hier_argument_id++) { - if (u.is(hier.getKernelOperand(hier_argument_id))) { - auto child_op = hier.getKernelArgument(hier_argument_id); - char rw_check = checkOperandReadOrWrite(child_op); - if (rw == 'n' || rw_check == rw) { - addAsyncDepToGraphIfNew(hier->getResult(0), op); - } + // check if the use inside hierarchy op matches with the tracing mode (r + // or w) + for (unsigned hier_argument_id = 0; + hier_argument_id < hier.getNumKernelOperands(); + hier_argument_id++) { + if (u.is(hier.getKernelOperand(hier_argument_id))) { + auto child_op = hier.getKernelArgument(hier_argument_id); + char rw_check = checkOperandReadOrWrite(child_op); + if (rw == 'n' || rw_check == rw) { + addAsyncDepToGraphIfNew(hier->getResult(0), op); } } } @@ -949,12 +944,8 @@ class AIRDependency : public air::impl::AIRDependencyBase { // If used in an unknown op else { - auto unknownop = u.getOwner(); - if (auto ar = dyn_cast(unknownop->getParentOp())) { - if (foundAsyncOpUsesAboveCurrentLine(&ar)) { - addAsyncDepToGraphIfNew(ar.getResult(0), op); - } - } + if (auto ar = dyn_cast(u.getOwner()->getParentOp())) + addAsyncDepToGraphIfNew(ar.getResult(0), op); } } } @@ -1649,37 +1640,37 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Other utilities //===----------------------------------------------------------------------===// - bool foundAsyncOpUsesAboveCurrentLine(air::ExecuteOp *op) { - if (!async_execute_op_history.empty()) - for (auto &iter : async_execute_op_history) - if (iter.getResult(0) == op->getResult(0)) - return true; - return false; - } - - bool foundAsyncOpUsesAboveCurrentLine(air::DmaMemcpyNdOp *op) { - if (!dma_op_history.empty()) - for (auto &iter : dma_op_history) - if (iter->getResult(0) == op->getOperation()->getResult(0)) - return true; - return false; - } - - bool foundAsyncOpUsesAboveCurrentLine(air::ChannelInterface *op) { - if (!channel_op_history.empty()) - for (auto &iter : channel_op_history) - if (iter->getResult(0) == op->getOperation()->getResult(0)) - return true; - return false; - } - - bool foundAsyncOpUsesAboveCurrentLine(air::HierarchyInterface *op) { - if (!hier_op_history.empty()) - for (auto &iter : hier_op_history) - if (iter->getResult(0) == op->getOperation()->getResult(0)) - return true; - return false; - } + // bool foundAsyncOpUsesAboveCurrentLine(air::ExecuteOp *op) { + // if (!async_execute_op_history.empty()) + // for (auto &iter : async_execute_op_history) + // if (iter.getResult(0) == op->getResult(0)) + // return true; + // return false; + // } + + // bool foundAsyncOpUsesAboveCurrentLine(air::DmaMemcpyNdOp *op) { + // if (!dma_op_history.empty()) + // for (auto &iter : dma_op_history) + // if (iter->getResult(0) == op->getOperation()->getResult(0)) + // return true; + // return false; + // } + + // bool foundAsyncOpUsesAboveCurrentLine(air::ChannelInterface *op) { + // if (!channel_op_history.empty()) + // for (auto &iter : channel_op_history) + // if (iter->getResult(0) == op->getOperation()->getResult(0)) + // return true; + // return false; + // } + + // bool foundAsyncOpUsesAboveCurrentLine(air::HierarchyInterface *op) { + // if (!hier_op_history.empty()) + // for (auto &iter : hier_op_history) + // if (iter->getResult(0) == op->getOperation()->getResult(0)) + // return true; + // return false; + // } // Check if two partial memref tiles have identical access patterns bool areEqualIndexPartialMemrefs(partialMemref *tile_0, From 10333a9c58ac5e202c014c4f7186f827ccb1859f Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 17:51:12 -0800 Subject: [PATCH 4/6] Purge global op_history vectors; reimplement getXXXOpFromVertex using map --- mlir/lib/Transform/AIRDependency.cpp | 105 +++++++++++---------------- 1 file changed, 42 insertions(+), 63 deletions(-) diff --git a/mlir/lib/Transform/AIRDependency.cpp b/mlir/lib/Transform/AIRDependency.cpp index 2437f97b7..617e5c9d4 100644 --- a/mlir/lib/Transform/AIRDependency.cpp +++ b/mlir/lib/Transform/AIRDependency.cpp @@ -250,6 +250,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { // 2nd traversal: trace deps among async execute regions; build a dep graph + llvm::DenseMap, Operation *> opIdToOpMap; for (auto f : module.getOps()) { f.walk([&](Operation *op) { Operation *sink_op = nullptr; @@ -290,14 +291,15 @@ class AIRDependency : public air::impl::AIRDependencyBase { async_execute_op); // Keep track of processed async execute region ops. Deps should point // to the past, not future. - async_execute_op_history.push_back(async_execute_op); + opIdToOpMap[std::make_pair("execute", async_execute_op.getId())] = + async_execute_op; } else if (auto dma_op = mlir::dyn_cast(op)) { traceDeps(sink_op_memref_reads, dma_op, "RAW"); traceDeps(sink_op_memref_writes, dma_op, "WAW/WAR"); traceTileIndices(sink_op_memref_reads, sink_op_memref_writes, sink_op_scalar_ins, sink_op_scalar_outs, dma_op); - dma_op_history.push_back(dma_op); + opIdToOpMap[std::make_pair("dma", dma_op.getId())] = dma_op; } else if (auto channel_op = mlir::dyn_cast(op)) { traceDeps(sink_op_memref_reads, channel_op, @@ -306,9 +308,10 @@ class AIRDependency : public air::impl::AIRDependencyBase { "WAW/WAR"); traceTileIndices(sink_op_memref_reads, sink_op_memref_writes, sink_op_scalar_ins, sink_op_scalar_outs, channel_op); - channel_op_history.push_back(channel_op); + opIdToOpMap[std::make_pair("channel", channel_op.getId())] = + channel_op; } else if (auto hier_op = dyn_cast(op)) { - hier_op_history.push_back(hier_op); + opIdToOpMap[std::make_pair("hierarchy", hier_op.getId())] = hier_op; } }); } @@ -324,19 +327,22 @@ class AIRDependency : public air::impl::AIRDependencyBase { f.walk([&](Operation *op) { // Fill dep list of air execute ops if (auto async_execute_op = dyn_cast(op)) { - fillAIRDepListUsingGraphTR(async_execute_op); + fillAIRDepListUsingGraphTR(async_execute_op, + opIdToOpMap); } // Fill dep list of air dmamemcpy2d ops else if (auto dma_op = dyn_cast(op)) { - fillAIRDepListUsingGraphTR(dma_op); + fillAIRDepListUsingGraphTR(dma_op, opIdToOpMap); } // Fill dep list of air channel ops else if (auto channel_op = dyn_cast(op)) { - fillAIRDepListUsingGraphTR(channel_op); + fillAIRDepListUsingGraphTR(channel_op, + opIdToOpMap); } // Fill dep list of air hierarchy ops else if (auto hier_op = dyn_cast(op)) { - fillAIRDepListUsingGraphTR(hier_op); + fillAIRDepListUsingGraphTR(hier_op, + opIdToOpMap); } }); } @@ -453,12 +459,6 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Creating async events //===----------------------------------------------------------------------===// - // Air async op history - std::vector async_execute_op_history; - std::vector dma_op_history; - std::vector channel_op_history; - std::vector hier_op_history; - // Create air execute op with async interface (no ssa result returned); update // graph air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op, @@ -1508,28 +1508,36 @@ class AIRDependency : public air::impl::AIRDependencyBase { wa_to_g; // Map between air wait_all and vertices in graph // g vertex to air op mapping - air::ExecuteOp getExecuteOpFromVertex(ExecuteGraph::VertexId v, - ExecuteGraph g) { + air::ExecuteOp getExecuteOpFromVertex( + ExecuteGraph::VertexId v, ExecuteGraph g, + llvm::DenseMap, Operation *> &opIdToOpMap) { assert(g[v].asyncEventType == "execute" && "This vertex is not a ExecuteOp"); - return async_execute_op_history[g[v].operationId - 1]; + auto op = opIdToOpMap[std::make_pair("execute", g[v].operationId)]; + return dyn_cast_if_present(op); } - air::DmaMemcpyNdOp getDmaOpFromVertex(ExecuteGraph::VertexId v, - ExecuteGraph g) { + air::DmaMemcpyNdOp getDmaOpFromVertex( + ExecuteGraph::VertexId v, ExecuteGraph g, + llvm::DenseMap, Operation *> &opIdToOpMap) { assert(g[v].asyncEventType == "dma" && "This vertex is not a DmaMemcpy op"); - return dma_op_history[g[v].operationId - 1]; + auto op = opIdToOpMap[std::make_pair("dma", g[v].operationId)]; + return dyn_cast_if_present(op); } - air::ChannelInterface getChannelOpFromVertex(ExecuteGraph::VertexId v, - ExecuteGraph g) { + air::ChannelInterface getChannelOpFromVertex( + ExecuteGraph::VertexId v, ExecuteGraph g, + llvm::DenseMap, Operation *> &opIdToOpMap) { assert(g[v].asyncEventType == "channel" && "This vertex is not a Channel op"); - return channel_op_history[g[v].operationId - 1]; + auto op = opIdToOpMap[std::make_pair("channel", g[v].operationId)]; + return dyn_cast_if_present(op); } - air::HierarchyInterface getHierOpFromVertex(ExecuteGraph::VertexId v, - ExecuteGraph g) { + air::HierarchyInterface getHierOpFromVertex( + ExecuteGraph::VertexId v, ExecuteGraph g, + llvm::DenseMap, Operation *> &opIdToOpMap) { assert(g[v].asyncEventType == "hierarchy" && "This vertex is not a Hierarchy op"); - return hier_op_history[g[v].operationId - 1]; + auto op = opIdToOpMap[std::make_pair("hierarchy", g[v].operationId)]; + return dyn_cast_if_present(op); } // air execute op to g vertex mapping @@ -1575,7 +1583,9 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Fill in dep list of air async ops using graph tr's connectivity template - void fillAIRDepListUsingGraphTR(T op) { + void fillAIRDepListUsingGraphTR( + T op, + llvm::DenseMap, Operation *> &opIdToOpMap) { if (auto async_op = mlir::dyn_cast(op.getOperation())) { uint64_t dstTRVertex = getGraphGVertexFromAIROp(op); @@ -1583,20 +1593,21 @@ class AIRDependency : public air::impl::AIRDependencyBase { asyncExecuteGraph.inverseAdjacentVertices(dstTRVertex)) { if (asyncExecuteGraph[TRVertex].asyncEventType == "execute") async_op.addAsyncDependency( - getExecuteOpFromVertex(TRVertex, asyncExecuteGraph).getResult(0)); + getExecuteOpFromVertex(TRVertex, asyncExecuteGraph, opIdToOpMap) + .getResult(0)); else if (asyncExecuteGraph[TRVertex].asyncEventType == "dma") async_op.addAsyncDependency( - getDmaOpFromVertex(TRVertex, asyncExecuteGraph) + getDmaOpFromVertex(TRVertex, asyncExecuteGraph, opIdToOpMap) .getOperation() ->getResult(0)); else if (asyncExecuteGraph[TRVertex].asyncEventType == "channel") async_op.addAsyncDependency( - getChannelOpFromVertex(TRVertex, asyncExecuteGraph) + getChannelOpFromVertex(TRVertex, asyncExecuteGraph, opIdToOpMap) .getOperation() ->getResult(0)); else if (asyncExecuteGraph[TRVertex].asyncEventType == "hierarchy") async_op.addAsyncDependency( - getHierOpFromVertex(TRVertex, asyncExecuteGraph) + getHierOpFromVertex(TRVertex, asyncExecuteGraph, opIdToOpMap) .getOperation() ->getResult(0)); else @@ -1640,38 +1651,6 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Other utilities //===----------------------------------------------------------------------===// - // bool foundAsyncOpUsesAboveCurrentLine(air::ExecuteOp *op) { - // if (!async_execute_op_history.empty()) - // for (auto &iter : async_execute_op_history) - // if (iter.getResult(0) == op->getResult(0)) - // return true; - // return false; - // } - - // bool foundAsyncOpUsesAboveCurrentLine(air::DmaMemcpyNdOp *op) { - // if (!dma_op_history.empty()) - // for (auto &iter : dma_op_history) - // if (iter->getResult(0) == op->getOperation()->getResult(0)) - // return true; - // return false; - // } - - // bool foundAsyncOpUsesAboveCurrentLine(air::ChannelInterface *op) { - // if (!channel_op_history.empty()) - // for (auto &iter : channel_op_history) - // if (iter->getResult(0) == op->getOperation()->getResult(0)) - // return true; - // return false; - // } - - // bool foundAsyncOpUsesAboveCurrentLine(air::HierarchyInterface *op) { - // if (!hier_op_history.empty()) - // for (auto &iter : hier_op_history) - // if (iter->getResult(0) == op->getOperation()->getResult(0)) - // return true; - // return false; - // } - // Check if two partial memref tiles have identical access patterns bool areEqualIndexPartialMemrefs(partialMemref *tile_0, partialMemref *tile_1) { From 66fd2ffcc3c1817a2633bd4f8475fac5ac6a1f58 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 18:23:18 -0800 Subject: [PATCH 5/6] Use default util --- mlir/include/air/Util/Util.h | 2 +- mlir/lib/Util/Util.cpp | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/include/air/Util/Util.h b/mlir/include/air/Util/Util.h index 270917945..e62ec01c5 100644 --- a/mlir/include/air/Util/Util.h +++ b/mlir/include/air/Util/Util.h @@ -258,7 +258,7 @@ void populateBufferMemrefToFuncArgsPattern(RewritePatternSet &patterns); // Find a common region that contains all ops, or ancestors of ops, until a // specified region. Region *findCommonRegionContainingAllAncestors(SmallVector ops, - Operation *until); + Operation *until = nullptr); } // namespace air } // namespace xilinx diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 4a4284ef8..c469bb2cb 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -1754,6 +1754,8 @@ air::findCommonRegionContainingAllAncestors(SmallVector ops, if (until && region->getParentOp() == until) return nullptr; region = region->getParentRegion(); + if (!region) + return nullptr; } return region; } From 72d467c8a883d912d40eb399c335d9c53d1e6e9a Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 22 Jan 2025 18:23:46 -0800 Subject: [PATCH 6/6] Switch to use findCommonRegionContainingAllAncestors --- mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp b/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp index 1528e18bf..1a86ac9ca 100644 --- a/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp +++ b/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp @@ -262,17 +262,14 @@ air::getRepeatCounts(std::vector memcpy_ops) { memcpyIOps = uniqueMemcpyIPattern; // Get the deepest region which is ancestor to all memcpyIOps. - Region *commonRegion = memcpyIOps.front()->getParentRegion(); - while (!llvm::all_of(memcpyIOps, [commonRegion](Operation *o) { - return commonRegion->isAncestor(o->getParentRegion()); - })) { - commonRegion = commonRegion->getParentRegion(); - if (!commonRegion) - return repeatCounts; - } + SmallVector memcpyIOpVec = memcpyIOps.takeVector(); + Region *commonRegion = + air::findCommonRegionContainingAllAncestors(memcpyIOpVec); + if (!commonRegion) + return repeatCounts; // Get each memcpy op's repeat count, relative to the common region. - for (auto o : memcpyIOps) { + for (auto o : memcpyIOpVec) { int tripCount = 1; Region *currRegion = o->getParentRegion(); while (commonRegion->isAncestor(currRegion)) {