From da078846c343e3fd389600747a99a60515cc46d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com>
Date: Thu, 9 Jan 2025 17:02:46 +0800
Subject: [PATCH] fix build

---
 src/Nncase.Schedule/Schedule/GraphTiler.cs     | 9 +++++----
 src/Nncase.Schedule/Transforms/AutoTilePass.cs | 4 ++++
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs
index 66a991771d..2f9a2f563b 100644
--- a/src/Nncase.Schedule/Schedule/GraphTiler.cs
+++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs
@@ -101,7 +101,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S
 
             if (!solveMemo.TryGetValue(primTree, out var memo))
             {
-                var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions);
+                var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind);
                 (inputBids, outputBids) = (result.Inputs, result.Outputs);
                 result.ScheduleBuffers();
                 var bodyBuilder = T.Sequential();
@@ -139,7 +139,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S
         return (resultMemo, objectValue);
     }
 
-    public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions)
+    public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind)
     {
         int[] memoryCapacities = targetOptions.MemoryCapacities;
         int[] memoryBandWidths = targetOptions.MemoryBandWidths;
@@ -434,8 +434,9 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere
             {
                 var binfo = bid.Node.GetKernelInfo(targetOptions).BufferInfos;
                 var reused = nodeInfo.DefUseMap.ContainsKey(bid);
-                for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++) // skip the buffer which store at top level
+                for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++)
                 {
+                    // skip the buffer which store at top level
                     var volumes = Enumerable.Repeat((IntExpr)solver.MakeIntConst(1), bufferInfo.Places.Length).ToArray();
                     for (int i = 0; i < bufferInfo.Places.Length; i++)
                     {
@@ -638,7 +639,7 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere
             DumpAssgin(primTree, new TreeSolverPythonPrinter(sol, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, targetOptions), tileVarConstraints, eachLevelStoreBufferNumsConstrains, levelBufferSizes, levelDataReads, levelDataWrites, memoryCycles, computeCycles, totalCyclesVar);
         }
 
-        return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions);
+        return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind);
     }
 
     public static void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary<OpNode, Constraint[]> tileVarConstraints, Dictionary<BufferIdentity, Constraint[]> lowestStoreBufferNumsConstrains, Dictionary<int, Dictionary<NodeWithBuffer, IntExpr>> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles)
diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs
index b9885a6286..391b80eb1c 100644
--- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs
+++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs
@@ -94,7 +94,11 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, Dictionary<Schedu
                 var cloner = new ReplacingExprCloner(ctx.VarMap[si].ToDictionary(kv => kv.Key, kv => (Expr)kv.Value));
                 var clonedCall = cloner.Clone(vertex.Expr, default); // replaces some exprs that are in the subgraph with var, avoid tiling the grids out of the subgraph.
                 using var scope = new Diagnostics.DumpScope($"tiling_func{funcNumber}_subfunc{subFuncNumber}");
+#if false
                 var tiledCall = GraphTiler.MCTSTiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
+#else
+                var tiledCall = GraphTiler.Tiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
+#endif
 
                 var varMap = ctx.VarMap[si].ToDictionary(kv => (Expr)kv.Value, kv => exprMemo[kv.Key]);
                 var substitutor = new Mutators.Substitutor(e =>