From da078846c343e3fd389600747a99a60515cc46d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Thu, 9 Jan 2025 17:02:46 +0800 Subject: [PATCH] fix build --- src/Nncase.Schedule/Schedule/GraphTiler.cs | 9 +++++---- src/Nncase.Schedule/Transforms/AutoTilePass.cs | 4 ++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index 66a991771d..2f9a2f563b 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -101,7 +101,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S if (!solveMemo.TryGetValue(primTree, out var memo)) { - var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions); + var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); (inputBids, outputBids) = (result.Inputs, result.Outputs); result.ScheduleBuffers(); var bodyBuilder = T.Sequential(); @@ -139,7 +139,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S return (resultMemo, objectValue); } - public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions) + public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind) { int[] memoryCapacities = targetOptions.MemoryCapacities; int[] memoryBandWidths = targetOptions.MemoryBandWidths; @@ -434,8 +434,9 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere { var binfo = bid.Node.GetKernelInfo(targetOptions).BufferInfos; var reused = nodeInfo.DefUseMap.ContainsKey(bid); - for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++) // skip the buffer which store at top level + for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++) { + // skip the buffer which store at top level var volumes = Enumerable.Repeat((IntExpr)solver.MakeIntConst(1), bufferInfo.Places.Length).ToArray(); for (int i = 0; i < bufferInfo.Places.Length; i++) { @@ -638,7 +639,7 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere DumpAssgin(primTree, new TreeSolverPythonPrinter(sol, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, targetOptions), tileVarConstraints, eachLevelStoreBufferNumsConstrains, levelBufferSizes, levelDataReads, levelDataWrites, memoryCycles, computeCycles, totalCyclesVar); } - return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions); + return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind); } public static void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary<OpNode, Constraint[]> tileVarConstraints, Dictionary<BufferIdentity, Constraint[]> lowestStoreBufferNumsConstrains, Dictionary<int, Dictionary<NodeWithBuffer, IntExpr>> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles) diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index b9885a6286..391b80eb1c 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -94,7 +94,11 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, Dictionary<Schedu var cloner = new ReplacingExprCloner(ctx.VarMap[si].ToDictionary(kv => kv.Key, kv => (Expr)kv.Value)); var clonedCall = cloner.Clone(vertex.Expr, default); // replaces some exprs that are in the subgraph with var, avoid tiling the grids out of the subgraph. using var scope = new Diagnostics.DumpScope($"tiling_func{funcNumber}_subfunc{subFuncNumber}"); +#if false var tiledCall = GraphTiler.MCTSTiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions); +#else + var tiledCall = GraphTiler.Tiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions); +#endif var varMap = ctx.VarMap[si].ToDictionary(kv => (Expr)kv.Value, kv => exprMemo[kv.Key]); var substitutor = new Mutators.Substitutor(e =>