From aad5731c84211ef1d0bd7e907a10b77b00d843a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 14 Jan 2025 17:12:55 +0800 Subject: [PATCH] fix build --- src/Nncase.Schedule/Schedule/GraphTiler.cs | 249 +++++++++--------- .../Schedule/TileGraph/GraphMCTS.cs | 15 +- .../Transforms/AutoTilePass.cs | 8 +- src/Nncase.Tests/Affine/UnitTestTileGraph.cs | 4 +- 4 files changed, 134 insertions(+), 142 deletions(-) diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index 4978852e4..0cba90c4e 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -14,132 +14,11 @@ namespace Nncase.Schedule; -public static class GraphTiler +public class GraphTiler { - public static Expr MCTSTiling(Expr preExpr, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - var topLevel = targetOptions.MemoryCapacities.Length; - var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); - var rootState = new MCTState(rootGraph, moduleKind, prefix, "0", solveMemo, targetOptions); - var rootNode = new MCTNode(rootState); - var searcher = new MCTSearcher(); - searcher.Search(rootNode); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - rootNode.Dump("SearchTree"); - } - - var bestState = (MCTState)searcher.BestMCTNode!.State; - var replaces = new Dictionary(); - foreach (var (oldExpr, v) in exprMemo) - { - if (bestState.Results.TryGetValue(v, out var newExpr)) - { - replaces.Add(oldExpr, newExpr); - } - } - - var cloner = new ReplacingExprCloner(replaces); - return cloner.Clone(preExpr, default); - } - - public int DeviceFuncionCount { get; set; } - - public Expr Tile(Expr preExpr, string moduleKind, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - var topLevel = targetOptions.MemoryCapacities.Length; - var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - rootGraph.Dump($"tile_graph"); - } - - var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, prefix, solveMemo, targetOptions); - var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); - return cloner.Clone(preExpr, default); - } - - public static (Dictionary ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - // bufferize root graph. - var bufferGraphMemo = rootGraph.Bufferize(); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph"); - } - - // condense the root graph. - var condensedGraph = rootGraph.Condense(); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot")) - { - using var writer = new StreamWriter(file); - writer.Write(condensedGraph.ToGraphviz(init => - { - init.FormatVertex += (_, arg) => - { - if (arg.Vertex is TieredTileGraph t) - { - arg.VertexFormat.Label = t.ToString(); - } - }; - })); - } - } + public int DeviceFuncionCount { get; private set; } - // convert root graph as tree. - var rootTree = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo); - - var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index)); - var resultMemo = new Dictionary(); - long objectValue = 0; - foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) - { - using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling); - var primTree = treeGraphMemo[primGraph]; - HashSet inputBids; - HashSet outputBids; - - if (!solveMemo.TryGetValue(primTree, out var memo)) - { - var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); - (inputBids, outputBids) = (result.Inputs, result.Outputs); - result.ScheduleBuffers(); - var bodyBuilder = T.Sequential(); - result.Visit(primTree, new(bodyBuilder, Array.Empty())); - var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); - var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder); - var primFunc = funcBuilder.Build(); - memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue); - solveMemo.Add(primTree, memo); - } - else - { - (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs(); - } - - objectValue += memo.ObjectValue; - var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); - resultMemo.Add(primGraph, finalCall); - - // save the output. - foreach (var outputBid in outputBids) - { - if (!argumentsMemo.TryGetValue(outputBid, out var _)) - { - foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer)) - { - argumentsMemo.Add(outEdge.Target, finalCall); - } - - argumentsMemo.Add(outputBid, finalCall); - } - } - } - - return (resultMemo, objectValue); - } + public Dictionary SolveMemo { get; } = new Dictionary(new ITreeNodeComparer()); public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind) { @@ -705,6 +584,128 @@ public static void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Diction } } + public (Dictionary ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, ICpuTargetOptions targetOptions) + { + // bufferize root graph. + var bufferGraphMemo = rootGraph.Bufferize(); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph"); + } + + // condense the root graph. + var condensedGraph = rootGraph.Condense(); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot")) + { + using var writer = new StreamWriter(file); + writer.Write(condensedGraph.ToGraphviz(init => + { + init.FormatVertex += (_, arg) => + { + if (arg.Vertex is TieredTileGraph t) + { + arg.VertexFormat.Label = t.ToString(); + } + }; + })); + } + } + + // convert root graph as tree. + var rootTree = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo); + + var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index)); + var resultMemo = new Dictionary(); + long objectValue = 0; + foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) + { + using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling); + var primTree = treeGraphMemo[primGraph]; + HashSet inputBids; + HashSet outputBids; + + if (!SolveMemo.TryGetValue(primTree, out var memo)) + { + var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); + (inputBids, outputBids) = (result.Inputs, result.Outputs); + result.ScheduleBuffers(); + var bodyBuilder = T.Sequential(); + result.Visit(primTree, new(bodyBuilder, Array.Empty())); + var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); + var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder); + var primFunc = funcBuilder.Build(); + memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue); + SolveMemo.Add(primTree, memo); + } + else + { + (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs(); + } + + objectValue += memo.ObjectValue; + var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); + resultMemo.Add(primGraph, finalCall); + + // save the output. + foreach (var outputBid in outputBids) + { + if (!argumentsMemo.TryGetValue(outputBid, out var _)) + { + foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer)) + { + argumentsMemo.Add(outEdge.Target, finalCall); + } + + argumentsMemo.Add(outputBid, finalCall); + } + } + } + + return (resultMemo, objectValue); + } + + public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions) + { +#if true + var topLevel = targetOptions.MemoryCapacities.Length; + var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + rootGraph.Dump($"tile_graph"); + } + + var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, targetOptions); + var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); + return cloner.Clone(preExpr, default); +#else + var topLevel = targetOptions.MemoryCapacities.Length; + var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); + var rootState = new MCTState(rootGraph, moduleKind, "0", SolveMemo, targetOptions); + var rootNode = new MCTNode(rootState); + var searcher = new MCTSearcher(); + searcher.Search(rootNode); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + rootNode.Dump("SearchTree"); + } + + var bestState = (MCTState)searcher.BestMCTNode!.State; + var replaces = new Dictionary(); + foreach (var (oldExpr, v) in exprMemo) + { + if (bestState.Results.TryGetValue(v, out var newExpr)) + { + replaces.Add(oldExpr, newExpr); + } + } + + var cloner = new ReplacingExprCloner(replaces); + return cloner.Clone(preExpr, default); +#endif + } + private static void DumpGantt(Dictionary nodeBufferSizes, Dictionary> nodeBufferLiveness, TileNode primTree, int storeLevel) { string GetStartStr(string name, int start) => $"[{name}] starts D+{start}"; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs index 856e743fa..57d74bd66 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs @@ -19,32 +19,29 @@ public sealed class MCTState : IEnvironmentState { private readonly string _path = string.Empty; + private readonly GraphTiler _graphTiler; + private readonly List _mergePoints = new(); private readonly List _legalIndex = new(); - private readonly Dictionary _solveMemo; - private readonly string _moduleKind; - private readonly string _prefix; - private readonly ICpuTargetOptions _targetOptions; private readonly TieredTileGraph _graph; private int _permformCount; - public MCTState(TieredTileGraph graph, string moduleKind, string prefix, string searchPath, Dictionary solveMemo, ICpuTargetOptions targetOptions) + public MCTState(TieredTileGraph graph, string moduleKind, string searchPath, GraphTiler graphTiler, ICpuTargetOptions targetOptions) { _graph = graph; _moduleKind = moduleKind; - _prefix = prefix; - _solveMemo = solveMemo; _targetOptions = targetOptions; _mergePoints.AddRange(graph.GetMergePoints()); _legalIndex.AddRange(Enumerable.Range(0, _mergePoints.Count)); _path = searchPath; + _graphTiler = graphTiler; Results = new(new LeafTileGraphComparer()); } @@ -70,7 +67,7 @@ public int LegalActions() var newGraph = _graph.Clone(); if (newGraph.Merge(mergePoint)) { - return new MCTState(newGraph, _moduleKind, _prefix, $"{_path}.{_permformCount}", _solveMemo, _targetOptions); + return new MCTState(newGraph, _moduleKind, $"{_path}.{_permformCount}", _graphTiler, _targetOptions); } return null; @@ -83,7 +80,7 @@ public double RollOut() using var scope = new Diagnostics.DumpScope($"RollOut{_path}"); try { - var res = GraphTiler.SolveRootGraph(_graph, _moduleKind, _prefix, _solveMemo, _targetOptions); + var res = _graphTiler.SolveRootGraph(_graph, _moduleKind, _targetOptions); ObjectValue = res.ObjectValue; foreach (var item in res.ResultMemo) { diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index 58fa43f44..dd23b905b 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -34,7 +34,7 @@ public AutoTilePass(string moduleKind, CompileOptions compileOptions) protected override Task RunCoreAsync(IRModule input, RunPassContext context) { - _ = new Dictionary(); + var tiler = new GraphTiler(); var funcNums = input.Functions.Count; for (int i = 0; i < funcNums; i++) { @@ -224,12 +224,6 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph { if (e is Var v && argumentDict.TryGetValue(v, out var arg)) diff --git a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs index 0050ad490..86540e9ce 100644 --- a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs +++ b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs @@ -360,8 +360,8 @@ public void TestMCTS(Func functor, int count) builder.Visit(post); var tileGraph = builder.RootGraph; - var memo = new Dictionary(new ITreeNodeComparer()); - var state = new MCTState(tileGraph, "cpu", count.ToString(), string.Empty, memo, targetOptions); + var tiler = new Schedule.GraphTiler(); + var state = new MCTState(tileGraph, "cpu", count.ToString(), tiler, targetOptions); var rootNode = new MCTNode(state); var searcher = new MCTSearcher(); searcher.Search(rootNode);