Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 9, 2025
1 parent 0e466aa commit da07884
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/Nncase.Schedule/Schedule/GraphTiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S

if (!solveMemo.TryGetValue(primTree, out var memo))
{
var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions);
var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind);
(inputBids, outputBids) = (result.Inputs, result.Outputs);
result.ScheduleBuffers();
var bodyBuilder = T.Sequential();
Expand Down Expand Up @@ -139,7 +139,7 @@ public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) S
return (resultMemo, objectValue);
}

public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions)
public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind)
{
int[] memoryCapacities = targetOptions.MemoryCapacities;
int[] memoryBandWidths = targetOptions.MemoryBandWidths;
Expand Down Expand Up @@ -434,8 +434,9 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere
{
var binfo = bid.Node.GetKernelInfo(targetOptions).BufferInfos;
var reused = nodeInfo.DefUseMap.ContainsKey(bid);
for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++) // skip the buffer which store at top level
for (int storeLevel = 0; storeLevel < Math.Min(tileNode.Level, topLevel - 1); storeLevel++)
{
// skip the buffer which store at top level
var volumes = Enumerable.Repeat((IntExpr)solver.MakeIntConst(1), bufferInfo.Places.Length).ToArray();
for (int i = 0; i < bufferInfo.Places.Length; i++)
{
Expand Down Expand Up @@ -638,7 +639,7 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<Tiere
DumpAssgin(primTree, new TreeSolverPythonPrinter(sol, solver, opNodeMemo, tileNodeMemo, tileableNodeMemo, targetOptions), tileVarConstraints, eachLevelStoreBufferNumsConstrains, levelBufferSizes, levelDataReads, levelDataWrites, memoryCycles, computeCycles, totalCyclesVar);
}

return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions);
return new TreeSolveResult(bufferGraphMemo[primTree.Wrapped], sol.ObjectiveValue(), levelBufferSizesAssgin, levelBufferLifeness, opNodeMemoAssgin, tileNodeMemoAssgin, tileableNodeMemoAssgin, targetOptions, moduleKind);
}

public static void DumpAssgin(ITreeNode tree, TreeSolverPythonPrinter printer, Dictionary<OpNode, Constraint[]> tileVarConstraints, Dictionary<BufferIdentity, Constraint[]> lowestStoreBufferNumsConstrains, Dictionary<int, Dictionary<NodeWithBuffer, IntExpr>> levelBufferSizes, IntExpr[] levelDataReads, IntExpr[] levelDataWrites, IntExpr[] memoryCycles, IntExpr computeCycles, IntVar totalCycles)
Expand Down
4 changes: 4 additions & 0 deletions src/Nncase.Schedule/Transforms/AutoTilePass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, Dictionary<Schedu
var cloner = new ReplacingExprCloner(ctx.VarMap[si].ToDictionary(kv => kv.Key, kv => (Expr)kv.Value));
var clonedCall = cloner.Clone(vertex.Expr, default); // replaces some exprs that are in the subgraph with var, avoid tiling the grids out of the subgraph.
using var scope = new Diagnostics.DumpScope($"tiling_func{funcNumber}_subfunc{subFuncNumber}");
#if false
var tiledCall = GraphTiler.MCTSTiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
#else
var tiledCall = GraphTiler.Tiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
#endif

var varMap = ctx.VarMap[si].ToDictionary(kv => (Expr)kv.Value, kv => exprMemo[kv.Key]);
var substitutor = new Mutators.Substitutor(e =>
Expand Down

0 comments on commit da07884

Please sign in to comment.