Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 14, 2025
1 parent 608e275 commit aad5731
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 142 deletions.
249 changes: 125 additions & 124 deletions src/Nncase.Schedule/Schedule/GraphTiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,132 +14,11 @@

namespace Nncase.Schedule;

public static class GraphTiler
public class GraphTiler
{
/// <summary>
/// Builds the tile graph of <paramref name="preExpr"/>, runs an <c>MCTSearcher</c> over it and
/// returns a clone of the expression in which every expression that has a result in the best
/// search state is replaced by that result.
/// </summary>
/// <param name="preExpr">Expression to tile.</param>
/// <param name="moduleKind">Module kind forwarded to the root search state.</param>
/// <param name="prefix">Prefix forwarded to the root search state.</param>
/// <param name="solveMemo">Memo of already solved trees, forwarded to the root search state.</param>
/// <param name="targetOptions">Target options; the number of memory capacities is used as the top level of the graph.</param>
/// <returns>The clone of <paramref name="preExpr"/> with the replacements applied.</returns>
public static Expr MCTSTiling(Expr preExpr, string moduleKind, string prefix, Dictionary<TileNode, TiledFunc> solveMemo, ICpuTargetOptions targetOptions)
{
    var graph = GraphBuilder.Build(preExpr, targetOptions.MemoryCapacities.Length, out var exprMemo);
    var root = new MCTNode(new MCTState(graph, moduleKind, prefix, "0", solveMemo, targetOptions));
    var mcts = new MCTSearcher();
    mcts.Search(root);

    var dumpScope = Diagnostics.DumpScope.Current;
    if (dumpScope.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        root.Dump("SearchTree");
    }

    var best = (MCTState)mcts.BestMCTNode!.State;
    var replaces = new Dictionary<Expr, Expr>();
    foreach (var entry in exprMemo)
    {
        // expressions whose graph has no result in the best state are left as they are.
        if (!best.Results.TryGetValue(entry.Value, out var tiled))
        {
            continue;
        }

        replaces.Add(entry.Key, tiled);
    }

    return new ReplacingExprCloner(replaces).Clone(preExpr, default);
}

public int DeviceFuncionCount { get; set; }

/// <summary>
/// Builds the tile graph of <paramref name="preExpr"/>, solves it with <c>SolveRootGraph</c> and
/// returns a clone of the expression in which every expression recorded by the graph builder is
/// replaced by the result solved for its graph.
/// </summary>
/// <param name="preExpr">Expression to tile.</param>
/// <param name="moduleKind">Module kind forwarded to <c>SolveRootGraph</c>.</param>
/// <param name="solveMemo">Memo of already solved trees, forwarded to <c>SolveRootGraph</c>.</param>
/// <param name="targetOptions">Target options; the number of memory capacities is used as the top level of the graph.</param>
/// <returns>The clone of <paramref name="preExpr"/> with the replacements applied.</returns>
public Expr Tile(Expr preExpr, string moduleKind, Dictionary<TileNode, TiledFunc> solveMemo, ICpuTargetOptions targetOptions)
{
    var topLevel = targetOptions.MemoryCapacities.Length;
    var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo);
    if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        rootGraph.Dump($"tile_graph");
    }

    // This method declares no `prefix`, so the identifier used here before did not resolve.
    // The five-argument SolveRootGraph does not read its prefix argument, so an empty one is passed.
    var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, string.Empty, solveMemo, targetOptions);
    var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value]));
    return cloner.Clone(preExpr, default);
}

public static (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, string prefix, Dictionary<TileNode, TiledFunc> solveMemo, ICpuTargetOptions targetOptions)
{
// bufferize root graph.
var bufferGraphMemo = rootGraph.Bufferize();
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
{
bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph");
}

// condense the root graph.
var condensedGraph = rootGraph.Condense();
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
{
using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot"))
{
using var writer = new StreamWriter(file);
writer.Write(condensedGraph.ToGraphviz(init =>
{
init.FormatVertex += (_, arg) =>
{
if (arg.Vertex is TieredTileGraph t)
{
arg.VertexFormat.Label = t.ToString();
}
};
}));
}
}
public int DeviceFuncionCount { get; private set; }

// convert root graph as tree.
var rootTree = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo);

var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index));
var resultMemo = new Dictionary<TieredTileGraph, Expr>();
long objectValue = 0;
foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i)))
{
using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling);
var primTree = treeGraphMemo[primGraph];
HashSet<BufferIdentity> inputBids;
HashSet<BufferIdentity> outputBids;

if (!solveMemo.TryGetValue(primTree, out var memo))
{
var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind);
(inputBids, outputBids) = (result.Inputs, result.Outputs);
result.ScheduleBuffers();
var bodyBuilder = T.Sequential();
result.Visit(primTree, new(bodyBuilder, Array.Empty<Expr>()));
var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray();
var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder);
var primFunc = funcBuilder.Build();
memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue);
solveMemo.Add(primTree, memo);
}
else
{
(inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs();
}

objectValue += memo.ObjectValue;
var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray());
resultMemo.Add(primGraph, finalCall);

// save the output.
foreach (var outputBid in outputBids)
{
if (!argumentsMemo.TryGetValue(outputBid, out var _))
{
foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer))
{
argumentsMemo.Add(outEdge.Target, finalCall);
}

argumentsMemo.Add(outputBid, finalCall);
}
}
}

return (resultMemo, objectValue);
}
public Dictionary<TileNode, TiledFunc> SolveMemo { get; } = new Dictionary<TileNode, TiledFunc>(new ITreeNodeComparer());

public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary<TieredTileGraph, BufferGraph> bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind)
{
Expand Down Expand Up @@ -705,6 +584,128 @@ public static void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Diction
}
}

/// <summary>
/// Bufferizes and condenses <paramref name="rootGraph"/>, then walks the condensed graph in
/// topological order and builds one call expression per sub graph. A sub graph whose tree is not
/// yet in <see cref="SolveMemo"/> is solved with <c>SolvePrimGraph</c>, wrapped into a new
/// <c>device_func_N</c> prim function and memoized; otherwise the memoized function is reused.
/// </summary>
/// <param name="rootGraph">The tile graph to solve.</param>
/// <param name="moduleKind">Module kind forwarded to <c>SolvePrimGraph</c> and to the generated prim functions.</param>
/// <param name="targetOptions">Target options forwarded to <c>SolvePrimGraph</c>.</param>
/// <returns>The call built for every condensed sub graph, and the sum of the object values of the functions used.</returns>
public (Dictionary<TieredTileGraph, Expr> ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, ICpuTargetOptions targetOptions)
{
    // bufferize root graph.
    var bufferGraphMemo = rootGraph.Bufferize();
    if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph");
    }

    // condense the root graph.
    var condensedGraph = rootGraph.Condense();
    if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        // both are disposed at the end of this block, writer first and then the file.
        using var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot");
        using var writer = new StreamWriter(file);
        writer.Write(condensedGraph.ToGraphviz(init =>
        {
            init.FormatVertex += (_, arg) =>
            {
                if (arg.Vertex is TieredTileGraph t)
                {
                    arg.VertexFormat.Label = t.ToString();
                }
            };
        }));
    }

    // convert root graph as tree; only the graph-to-tree memo is needed below.
    _ = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo);

    // maps every buffer to the expression that provides it: graph inputs first, produced outputs later.
    var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index));
    var resultMemo = new Dictionary<TieredTileGraph, Expr>();
    long objectValue = 0;
    foreach (var primGraph in condensedGraph.TopologicalSort())
    {
        using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling);
        var primTree = treeGraphMemo[primGraph];
        HashSet<BufferIdentity> inputBids;
        HashSet<BufferIdentity> outputBids;

        if (!SolveMemo.TryGetValue(primTree, out var memo))
        {
            var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind);
            (inputBids, outputBids) = (result.Inputs, result.Outputs);
            result.ScheduleBuffers();
            var bodyBuilder = T.Sequential();
            result.Visit(primTree, new(bodyBuilder, Array.Empty<Expr>()));
            var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray();
            var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder);
            var primFunc = funcBuilder.Build();
            memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue);
            SolveMemo.Add(primTree, memo);
        }
        else
        {
            // memo hit: the function is reused, only the buffers of this sub graph are looked up.
            (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs();
        }

        objectValue += memo.ObjectValue;
        var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray());
        resultMemo.Add(primGraph, finalCall);

        // save the output.
        foreach (var outputBid in outputBids)
        {
            if (!argumentsMemo.ContainsKey(outputBid))
            {
                foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer))
                {
                    argumentsMemo.Add(outEdge.Target, finalCall);
                }

                argumentsMemo.Add(outputBid, finalCall);
            }
        }
    }

    return (resultMemo, objectValue);
}

/// <summary>
/// Builds the tile graph of <paramref name="preExpr"/> and returns a clone of the expression with
/// the tiled results substituted in. The enabled branch solves the graph directly with
/// <see cref="SolveRootGraph"/>; the disabled branch searches with <c>MCTSearcher</c> instead.
/// </summary>
/// <param name="preExpr">Expression to tile.</param>
/// <param name="moduleKind">Module kind forwarded to the solver.</param>
/// <param name="targetOptions">Target options; the number of memory capacities is used as the top level of the graph.</param>
/// <returns>The clone of <paramref name="preExpr"/> with the replacements applied.</returns>
public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions)
{
#if true
    var topLevel = targetOptions.MemoryCapacities.Length;
    var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo);
    if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        rootGraph.Dump($"tile_graph");
    }

    var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, targetOptions);
    var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value]));
    return cloner.Clone(preExpr, default);
#else
    var topLevel = targetOptions.MemoryCapacities.Length;
    var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo);

    // MCTState takes the tiler itself (it calls SolveRootGraph on it), not the solve memo.
    var rootState = new MCTState(rootGraph, moduleKind, "0", this, targetOptions);
    var rootNode = new MCTNode(rootState);
    var searcher = new MCTSearcher();
    searcher.Search(rootNode);
    if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling))
    {
        rootNode.Dump("SearchTree");
    }

    var bestState = (MCTState)searcher.BestMCTNode!.State;
    var replaces = new Dictionary<Expr, Expr>();
    foreach (var (oldExpr, v) in exprMemo)
    {
        if (bestState.Results.TryGetValue(v, out var newExpr))
        {
            replaces.Add(oldExpr, newExpr);
        }
    }

    var cloner = new ReplacingExprCloner(replaces);
    return cloner.Clone(preExpr, default);
#endif
}

private static void DumpGantt(Dictionary<NodeWithBuffer, IntExpr> nodeBufferSizes, Dictionary<NodeWithBuffer, Tuple<int, int>> nodeBufferLiveness, TileNode primTree, int storeLevel)
{
string GetStartStr(string name, int start) => $"[{name}] starts D+{start}";
Expand Down
15 changes: 6 additions & 9 deletions src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,29 @@ public sealed class MCTState : IEnvironmentState<MergePoint>
{
private readonly string _path = string.Empty;

private readonly GraphTiler _graphTiler;

private readonly List<MergePoint> _mergePoints = new();

private readonly List<int> _legalIndex = new();

private readonly Dictionary<TileNode, GraphTiler.TiledFunc> _solveMemo;

private readonly string _moduleKind;

private readonly string _prefix;

private readonly ICpuTargetOptions _targetOptions;

private readonly TieredTileGraph _graph;

private int _permformCount;

public MCTState(TieredTileGraph graph, string moduleKind, string prefix, string searchPath, Dictionary<TileNode, GraphTiler.TiledFunc> solveMemo, ICpuTargetOptions targetOptions)
public MCTState(TieredTileGraph graph, string moduleKind, string searchPath, GraphTiler graphTiler, ICpuTargetOptions targetOptions)
{
_graph = graph;
_moduleKind = moduleKind;
_prefix = prefix;
_solveMemo = solveMemo;
_targetOptions = targetOptions;
_mergePoints.AddRange(graph.GetMergePoints());
_legalIndex.AddRange(Enumerable.Range(0, _mergePoints.Count));
_path = searchPath;
_graphTiler = graphTiler;
Results = new(new LeafTileGraphComparer());
}

Expand All @@ -70,7 +67,7 @@ public int LegalActions()
var newGraph = _graph.Clone();
if (newGraph.Merge(mergePoint))
{
return new MCTState(newGraph, _moduleKind, _prefix, $"{_path}.{_permformCount}", _solveMemo, _targetOptions);
return new MCTState(newGraph, _moduleKind, $"{_path}.{_permformCount}", _graphTiler, _targetOptions);
}

return null;
Expand All @@ -83,7 +80,7 @@ public double RollOut()
using var scope = new Diagnostics.DumpScope($"RollOut{_path}");
try
{
var res = GraphTiler.SolveRootGraph(_graph, _moduleKind, _prefix, _solveMemo, _targetOptions);
var res = _graphTiler.SolveRootGraph(_graph, _moduleKind, _targetOptions);
ObjectValue = res.ObjectValue;
foreach (var item in res.ResultMemo)
{
Expand Down
8 changes: 1 addition & 7 deletions src/Nncase.Schedule/Transforms/AutoTilePass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public AutoTilePass(string moduleKind, CompileOptions compileOptions)

protected override Task<IRModule> RunCoreAsync(IRModule input, RunPassContext context)
{
_ = new Dictionary<Schedule.TileGraph.TileNode, GraphTiler.TiledFunc>();
var tiler = new GraphTiler();
var funcNums = input.Functions.Count;
for (int i = 0; i < funcNums; i++)
{
Expand Down Expand Up @@ -224,12 +224,6 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph<ExprVertex,

Expr cloned = clones.Count == 1 ? clones[0] : new IR.Tuple(clones.ToArray());
var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions);

// #if false
// var tiledCall = GraphTiler.MCTSTiling(clonedCall, ModuleKind, memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
// #else
// var tiledCall = GraphTiler.Tiling(clonedCall, ModuleKind, memo, (ICpuTargetOptions)CompileOptions.TargetOptions);
// #endif
var substitutor = new Mutators.Substitutor(e =>
{
if (e is Var v && argumentDict.TryGetValue(v, out var arg))
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Tests/Affine/UnitTestTileGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ public void TestMCTS(Func<Function> functor, int count)
builder.Visit(post);
var tileGraph = builder.RootGraph;

var memo = new Dictionary<TileNode, Schedule.GraphTiler.TiledFunc>(new ITreeNodeComparer());
var state = new MCTState(tileGraph, "cpu", count.ToString(), string.Empty, memo, targetOptions);
var tiler = new Schedule.GraphTiler();
var state = new MCTState(tileGraph, "cpu", count.ToString(), tiler, targetOptions);
var rootNode = new MCTNode(state);
var searcher = new MCTSearcher();
searcher.Search(rootNode);
Expand Down

0 comments on commit aad5731

Please sign in to comment.