diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs index 260c963d5..caabc89f0 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs @@ -12,7 +12,7 @@ public record BufferRenderInfo(string Name, string ElemType, ulong Offset, ulong { } -public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize) +public record KernelMainModel(TIR.PrimFunction PrimFunction, TIR.Buffer[] RDataBuffers, CpuTargetOptions Options, ulong Alignment, ulong DataSize, ulong RDataSize, ulong LocalRdataPoolSize) { public BufferRenderInfo GetInfo(TIR.Buffer buffer) { @@ -64,9 +64,9 @@ public static string CMakeDef(string name) return content; } - public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, IEnumerable rdataBuffers, CpuTargetOptions options) + public static string MakeMain(TIR.PrimFunction primFunction, ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, IEnumerable rdataBuffers, CpuTargetOptions options) { - var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize)).Result; + var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/thread_main.cpp.cshtml", new KernelMainModel(primFunction, rdataBuffers.ToArray(), options, dataAlign, dataUsage, rdataPoolSize, localRdataPoolSize)).Result; return content; } diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs index 4f34ff762..5b7d6446d 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs @@ -82,7 +82,7 @@ public unsafe ILinkableFunction Build(TIR.PrimFunction function) } // 3. build function. - var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, TargetOptions); + var visitor = new KernelCSourceConvertVisitor(function.SchedResult.DataAlign, function.SchedResult.DataUsage, rdataPoolSize, localRdataPoolSize, TargetOptions); visitor.Visit(function); var functionCSource = visitor.GetCSource(); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs index 4213a0112..9ad7d8955 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs @@ -119,11 +119,12 @@ internal sealed class KernelCSourceConvertVisitor : ExprFunctor, private readonly StringWriter _sharedWriter; private ulong _collective_pool_size; - public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, CpuTargetOptions targetOptions) + public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdataPoolSize, ulong localRdataPoolSize, CpuTargetOptions targetOptions) { DataAlign = dataAlign; DataUsage = dataUsage; RdataPoolSize = rdataPoolSize; + LocalRdataPoolSize = localRdataPoolSize; _kernelBuilder = new StringBuilder(); _sharedBuilder = new StringBuilder(); _sharedWriter = new StringWriter(_sharedBuilder); @@ -145,11 +146,13 @@ public KernelCSourceConvertVisitor(ulong dataAlign, ulong dataUsage, ulong rdata public ulong RdataPoolSize { get; } + public ulong LocalRdataPoolSize { get; } + public KernelCSource GetCSource() { var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* data)"; return new( - CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, _exprMemo.Keys.OfType().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions), + CSourceBuiltn.MakeMain(VisitEntry, DataAlign, DataUsage, RdataPoolSize, LocalRdataPoolSize, _exprMemo.Keys.OfType().Where(b => b.MemSpan.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata), TargetOptions), CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()), CSourceBuiltn.TopoAwareRuntimeDef(TargetOptions, DataAlign, _collective_pool_size), CSourceBuiltn.TopologyDef(TargetOptions)); diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml index 2270dabec..12a58e7b6 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml @@ -64,6 +64,11 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) { } std::byte* rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.RDataSize, align); + std::byte* local_rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.LocalRdataPoolSize, align); + uint64_t local_rdata_header[@Model.Options.Hierarchies[0][^1] * 2]; + for (size_t tid = 0; tid < tdim(); tid++) { + local_rdata_header[tid * 2] = tid * ( @Model.LocalRdataPoolSize / tdim()); + } #ifdef __APPLE__ pthread_key_t cpu_thread_context_key_ = {}; @@ -73,7 +78,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) { std::vector blocks; for (size_t cid = 0; cid < cdim(); cid++) { for (size_t bid = 0; bid < bdim(); bid++) { - blocks.emplace_back([cid, bid, inputs, rdata + blocks.emplace_back([cid, bid, inputs, rdata, local_rdata_header, local_rdata #ifdef __APPLE__ , &cpu_thread_context_key_ #endif @@ -87,6 +92,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) { .cpu_id_offset = (cid * bdim() + bid) * tdim(), .inouts = inputs, .rdata = rdata, + .local_rdata_header = local_rdata_header, + .local_rdata = local_rdata, #ifdef __APPLE__ .cpu_thread_context_key = cpu_thread_context_key_, #endif diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs index 9deea14fe..427b73993 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs @@ -9,7 +9,7 @@ namespace Nncase.Evaluator.TIR.CPU; -public sealed class BinaryEvaluator : ITypeInferencer, IKernelInfoEvaluator +public sealed class BinaryEvaluator : ITypeInferencer, IKernelInfoEvaluator, IOpPrinter { public IRType Visit(ITypeInferenceContext context, Binary target) { @@ -29,6 +29,11 @@ public MicroKernelInfo Visit(Binary op, MicroKernelContext context) return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle); } + public string Visit(IIRPrinterContext context, Binary target, bool iLmode) + { + return $"Binary({target.DisplayProperty()}, {context.GetArgument(target, Binary.Lhs)}, {context.GetArgument(target, Binary.Rhs)}, {context.GetArgument(target, Binary.Output)})"; + } + private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context) { var factora = System.Math.Min(context.BufferShapes[0][^1], 32); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs index 2f298ef60..d58b90f4a 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs @@ -10,7 +10,7 @@ namespace Nncase.Evaluator.TIR.CPU; -public sealed class UnaryEvaluator : ITypeInferencer, IKernelInfoEvaluator +public sealed class UnaryEvaluator : ITypeInferencer, IKernelInfoEvaluator, IOpPrinter { public IRType Visit(ITypeInferenceContext context, Unary target) { @@ -31,6 +31,11 @@ public MicroKernelInfo Visit(Unary op, MicroKernelContext context) return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle); } + public string Visit(IIRPrinterContext context, Unary target, bool iLmode) + { + return $"Unary({target.DisplayProperty()}, {context.GetArgument(target, Unary.Input)}, {context.GetArgument(target, Unary.Output)})"; + } + private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context) { var factor = System.Math.Min(context.BufferShapes[0][^1], 32); diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs index 815f0d1c1..c54c99d3b 100644 --- a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs +++ b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs @@ -2,10 +2,13 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using NetFabric.Hyperlinq; +using Nncase.Graphs; using Nncase.IR; using Nncase.IR.Tensors; using Nncase.Passes.GraphPartition; using QuikGraph; +using QuikGraph.Algorithms; +using QuikGraph.Graphviz; namespace Nncase.Passes; @@ -23,85 +26,199 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext c var funcs = module.Functions.Count; for (int i = 0; i < funcs; i++) { - if (module.Functions[i] is Function function) + if (module.Functions[i] is not Function function) { - Function pre = function; + continue; + } - // Function post; - var ctx = new GraphPartition.GraphContext(); - var convertor = new GraphPartition.GraphConvertor(x => x switch - { - Call call => (call.Target is IR.CPU.Boxing || call.CheckedType is DistributedType) ? true : false, - IR.Tuple tp => tp.Fields.ToArray().Any(f => f is Call { Target: IR.CPU.Boxing } b && b.CheckedType is TensorType) ? false : true, - _ => throw new NotSupportedException(), - }); - convertor.Visit(pre.Body, ctx); - -#if Debug - ctx.Graph.DumpDot(DumpScope.Current.Directory + $"function_{i}.dot"); -#endif - - ctx.SummarizeGraph(); - -#if Debug - ctx.GraphSummary.DumpDot(DumpScope.Current.Directory + $"function_{i}_summary.dot"); -#endif - - var dfsVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm(ctx.GraphSummary); - dfsVisitor.Compute(); - var exprMemo = new Dictionary(ReferenceEqualityComparer.Instance); - for (var vi = 0; vi < dfsVisitor.SortedVertices.Length; vi++) + Function pre = function; + var postBody = PerformPartition(pre.Name, pre.Body); + var post = pre.With(pre.Name, pre.ModuleKind, postBody, pre.Parameters.ToArray()); + module.Replace(i, post); + } + + return Task.FromResult(module); + } + + private Expr PerformPartition(string funcName, Expr pre) + { + // 1. convert to quikgraph + var biGraph = new BidirectionalGraph(false); + { + var graphConvertor = new ExprGraphConvertor(); + graphConvertor.Visit(pre, biGraph); + } + + // 2. perform condensation + var condenseAlgo = new CondensationGraphAlgorithm(biGraph); + condenseAlgo.IsEdgeCompatible += (algo, arg) => + { + bool CheckField(Expr f) + { + if (f is Call c && c.Target is IR.CPU.Boxing { NewType: TensorType } && c.Arguments[0].CheckedType is DistributedType) { - var vertex = dfsVisitor.SortedVertices[vi]; - var subgraph = ctx.SubgraphMap[ctx.SummaryVertexSubgraphMap[vertex]]; - if (vertex.CompatType == Compat.INCOMPATIBLE) + return true; + } + + return f.CheckedType is DistributedType; + } + + bool isSupport = false; + switch (arg.Edge.Target.Expr) + { + case Call call: + if (call.Target is IR.CPU.Boxing { NewType: TensorType } && call.Arguments[0].CheckedType is DistributedType) { - var sg = new Graph(); - subgraph.Nodes.ForEach(n => sg.AddVertex(n)); - subgraph.InteriorEdges.ForEach(e => sg.AddEdge(e)); - - // sg.DumpDot(DumpScope.Current.Directory + $"_Incompatible_{subgraph.Index}_{vi}.dot"); - var sgVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm(sg); - sgVisitor.Compute(); - foreach (var v in sgVisitor.SortedVertices) + isSupport = true; + } + else if (call.Target is IR.CPU.Boxing { NewType: DistributedType }) + { + if (arg.Edge.Source.Expr.CheckedType is not TensorType) { - var expr = v.Expr switch - { - Call c => c.With(arguments: c.Arguments.AsValueEnumerable().Select(arg => exprMemo[arg]).ToArray()), - IR.Tuple t => t.With(fields: t.Fields.AsValueEnumerable().Select(arg => exprMemo[arg]).ToArray()), - _ => v.Expr, - }; - exprMemo.Add(v.Expr, expr); + isSupport = true; } } - else + else if (call.CheckedType is DistributedType) { - var newInputs = ctx.VarMap[ctx.SummaryVertexSubgraphMap[vertex]].Values.ToArray(); - var merger = new Rules.FusionMerger(ctx.VarMap[ctx.SummaryVertexSubgraphMap[vertex]]); - var clonedRoot = merger.Clone(vertex.Expr, default); + isSupport = true; + } - if (clonedRoot is IR.Tuple tuple) - { - clonedRoot = new IR.Tuple(tuple.Fields.AsValueEnumerable().Select(f => f.CheckedType is DistributedType d ? IR.F.CPU.Boxing(f, d.TensorType) : f).ToArray()); - } + break; + case IR.Tuple tp: + isSupport = tp.Fields.AsValueEnumerable().All(f => f is Call c && CheckField(c)) ? true : false; + break; + default: + break; + } - var rootCall = new Call(new Fusion($"Function_{i}_fusion_{vi}_kernel", ModuleKind, clonedRoot, newInputs), ctx.VarMap[ctx.SummaryVertexSubgraphMap[vertex]].Keys.Select(e => exprMemo[e]).ToArray()); - if (ctx.OutputMap[subgraph.Index].Count > 1) - { - ctx.OutputMap[subgraph.Index].ToList().ForEach(e => exprMemo.Add(e.Key, new Call(new GetItem(), rootCall, e.Value))); - } - else - { - exprMemo.Add(ctx.OutputMap[subgraph.Index].Keys.First(), rootCall); - } + return isSupport; + }; + + condenseAlgo.IsGraphCompatible += (algo, edge) => + { + return algo.CondensedGraph.IsDirectedAcyclicGraph(); + }; + + condenseAlgo.Compute(); + + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Rewrite)) + { + condenseAlgo.CondensedGraph.Dump($"{funcName}Condensed", init => { }); + condenseAlgo.ClusteredGraph.Dump($"{funcName}Cluster", algo => + { + algo.FormatVertex += (s, arg) => + { + arg.VertexFormat.Label = $"{arg.Vertex.Expr.GetType().Name}"; + }; + }); + } + + // 3. reconstruction + var constructor = new DistributedReConstructor(funcName, ModuleKind, condenseAlgo); + var post = constructor.Construct(); + return post; + } +} + +internal sealed class DistributedReConstructor : ExprReConstructor +{ + public DistributedReConstructor(string funcName, string moduleKind, CondensationGraphAlgorithm algo) + : base(algo) + { + FuncName = funcName; + ModuleKind = moduleKind; + } + + public string FuncName { get; } + + public string ModuleKind { get; } + + protected override Expr OnComplexCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + var pairs = GetClusterArgumentPairs(cluster); + var paramDict = new Dictionary(ReferenceEqualityComparer.Instance); + var extractDict = new Dictionary(ReferenceEqualityComparer.Instance); + var argumentDict = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var (pre, post) in pairs) + { + if (pre is Const) + { + continue; + } + + Var @var; + Expr extract; + if (pre.CheckedType is DistributedType d) + { + @var = new Var(d.TensorType); + extract = IR.F.CPU.Boxing(@var, d); + } + else + { + @var = new Var(pre.CheckedType); + extract = @var; + } + + var added = paramDict.TryAdd(pre, @var); + if (added) + { + extractDict.Add(pre, extract); + argumentDict.Add(@var, post); + } + } + + var cloner = new ExprClusterCloner(extractDict); + var outVertices = cluster.OutVertices().ToArray(); + var clones = new List(); + foreach (var outVertex in outVertices) + { + clones.Add(cloner.Clone(outVertex.Expr, default)); + } + + var cloned = PostProcess(clones); + var fusion = new Fusion($"{FuncName}_{sortIndex}_kernel", ModuleKind, cloned, paramDict.Values.OfType().ToArray()); + return new Call(fusion, paramDict.Values.OfType().Select(v => argumentDict[v]).ToArray()); + } + + private Expr PostProcess(List clones) + { + Expr PostProcessSingle(Expr cloned, out bool changed) + { + changed = false; + switch (cloned) + { + case IR.Tuple tp: + var nFields = new List(); + foreach (var item in tp.Fields) + { + nFields.Add(PostProcessSingle(item, out var childChanged)); + changed |= childChanged; + } + + if (changed) + { + return new IR.Tuple(nFields.ToArray()); + } + else + { + return tp; } - } - var post = pre.With(pre.Name, pre.ModuleKind, exprMemo[pre.Body], pre.Parameters.ToArray()); - module.Replace(i, post); + case Expr e when e.CheckedType is DistributedType d: + changed = true; + return IR.F.CPU.Boxing(e, d.TensorType); + default: + return cloned; } } - return Task.FromResult(module); + if (clones.Count == 1) + { + return PostProcessSingle(clones[0], out _); + } + else + { + return new IR.Tuple(clones.Select(c => PostProcessSingle(c, out _)).ToArray()); + } } } diff --git a/modules/Nncase.Modules.CPU/packages.lock.json b/modules/Nncase.Modules.CPU/packages.lock.json index d9d3e1b58..1b003b1c1 100644 --- a/modules/Nncase.Modules.CPU/packages.lock.json +++ b/modules/Nncase.Modules.CPU/packages.lock.json @@ -186,7 +186,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.io": { diff --git a/modules/Nncase.Modules.StackVM/packages.lock.json b/modules/Nncase.Modules.StackVM/packages.lock.json index e8b2a0b12..bea955106 100644 --- a/modules/Nncase.Modules.StackVM/packages.lock.json +++ b/modules/Nncase.Modules.StackVM/packages.lock.json @@ -174,7 +174,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.io": { diff --git a/src/Nncase.Cli/packages.lock.json b/src/Nncase.Cli/packages.lock.json index b83346286..49433d8de 100644 --- a/src/Nncase.Cli/packages.lock.json +++ b/src/Nncase.Cli/packages.lock.json @@ -727,7 +727,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.importer": { diff --git a/src/Nncase.Compiler/packages.lock.json b/src/Nncase.Compiler/packages.lock.json index 6bfb830ed..a99b3a43d 100644 --- a/src/Nncase.Compiler/packages.lock.json +++ b/src/Nncase.Compiler/packages.lock.json @@ -708,7 +708,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.importer": { diff --git a/src/Nncase.Core/ReferenceEqualityComparer.cs b/src/Nncase.Core/ReferenceEqualityComparer.cs new file mode 100644 index 000000000..6295014f8 --- /dev/null +++ b/src/Nncase.Core/ReferenceEqualityComparer.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; + +namespace System.Collections.Generic; + +public sealed class ReferenceEqualityComparer : IEqualityComparer, IEqualityComparer + where T : class +{ + public new bool Equals(object? x, object? y) => ReferenceEquals(x, y); + + public bool Equals(T? x, T? y) => ReferenceEquals(x, y); + + public int GetHashCode(T? obj) + { + // Depending on target framework, RuntimeHelpers.GetHashCode might not be annotated + // with the proper nullability attribute. We'll suppress any warning that might + // result. + return RuntimeHelpers.GetHashCode(obj!); + } + + public int GetHashCode(object obj) => throw new NotImplementedException(); +} diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index b2c08f8d6..4507a3bf1 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -574,14 +574,14 @@ protected override string VisitBufferOf(BufferOf expr) /// protected override string VisitGrid(IR.Affine.Grid expr) { + var reads = expr.Reads.AsValueEnumerable().Select(Visit).ToArray(); + var buffers = expr.Buffers.AsValueEnumerable().Select(Visit).ToArray(); if (_names.TryGetValue(expr, out var name)) { return name; } name = AllocateTempVar(expr); - var reads = expr.Reads.AsValueEnumerable().Select(Visit).ToArray(); - var buffers = expr.Buffers.AsValueEnumerable().Select(Visit).ToArray(); _scope.Push(); // 1. For Loop signature diff --git a/src/Nncase.EGraph/packages.lock.json b/src/Nncase.EGraph/packages.lock.json index 8fb2625a9..e55d6a5f8 100644 --- a/src/Nncase.EGraph/packages.lock.json +++ b/src/Nncase.EGraph/packages.lock.json @@ -186,7 +186,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "CommunityToolkit.HighPerformance": { @@ -254,6 +255,15 @@ "resolved": "2.5.0", "contentHash": "sG+mrPpXwxlXknRK5VqWUGiOmDACa9X+3ftlkQIMgOZUqxVOQSe0+HIU9PTjwqazy0pqSf8MPDXYFGl0GYWcKw==" }, + "QuikGraph.Graphviz": { + "type": "CentralTransitive", + "requested": "[2.5.0, )", + "resolved": "2.5.0", + "contentHash": "pCKpErtHGxUi72OT+2aIg1pdHdUqpqEM5J/i9rmVsEVDE4X0xb1HBPWdxv/FLZmbBjk0ZogZXZttUL3CnAPpNw==", + "dependencies": { + "QuikGraph": "2.5.0" + } + }, "System.CommandLine": { "type": "CentralTransitive", "requested": "[2.0.0-beta4.22272.1, )", diff --git a/src/Nncase.Evaluator/Tensors/GetItem.cs b/src/Nncase.Evaluator/Tensors/GetItem.cs index 4c3d3a76d..dc3ea54e3 100644 --- a/src/Nncase.Evaluator/Tensors/GetItem.cs +++ b/src/Nncase.Evaluator/Tensors/GetItem.cs @@ -97,7 +97,7 @@ private IValue Visit(IValue input, IValue index) private IRType Visit(ITypeInferenceContext context, GetItem target, IRType input, TensorType index) { - IRType ret = new InvalidType("Need Be Reset!"); + IRType ret = new InvalidType("GetItem typeinfer error!"); var indexExpr = context.GetArgument(target, GetItem.Index); switch (input) { diff --git a/src/Nncase.Graph/Graphs/ClusteredBidirectionalGraph.cs b/src/Nncase.Graph/Graphs/ClusteredBidirectionalGraph.cs new file mode 100644 index 000000000..b75c10ed2 --- /dev/null +++ b/src/Nncase.Graph/Graphs/ClusteredBidirectionalGraph.cs @@ -0,0 +1,625 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +#if SUPPORTS_AGGRESSIVE_INLINING +using System.Runtime.CompilerServices; +#endif +using JetBrains.Annotations; +using QuikGraph; +using QuikGraph.Collections; + +namespace Nncase.Graphs; + +/// +/// Mutable clustered adjacency graph data structure. +/// +/// Vertex type. +/// Edge type. +#if SUPPORTS_SERIALIZATION + [Serializable] +#endif +[DebuggerDisplay("VertexCount = {" + nameof(VertexCount) + "}, EdgeCount = {" + nameof(EdgeCount) + "}")] +public class ClusteredBidirectionalGraph + : IEdgeListAndIncidenceGraph, + IMutableBidirectionalGraph, + IClusteredGraph + where TEdge : IEdge +{ + [NotNull] + private readonly List _clusters = new List(); + + /// + /// Initializes a new instance of the class. + /// + /// Graph to wrap. + /// is . + public ClusteredBidirectionalGraph(BidirectionalGraph wrappedGraph) + { + Parent = null; + Wrapped = wrappedGraph ?? throw new ArgumentNullException(nameof(wrappedGraph)); + Collapsed = false; + } + + /// + /// Initializes a new instance of the class. + /// + /// Parent graph. + /// is . + public ClusteredBidirectionalGraph(ClusteredBidirectionalGraph parentGraph) + { + Parent = parentGraph ?? throw new ArgumentNullException(nameof(parentGraph)); + Wrapped = new BidirectionalGraph(parentGraph.AllowParallelEdges); + } + + public event VertexAction VertexAdded + { + add + { + Wrapped.VertexAdded += value; + } + + remove + { + Wrapped.VertexAdded -= value; + } + } + + public event VertexAction VertexRemoved + { + add + { + Wrapped.VertexRemoved += value; + } + + remove + { + Wrapped.VertexRemoved -= value; + } + } + + public event EdgeAction EdgeAdded + { + add + { + Wrapped.EdgeAdded += value; + } + + remove + { + Wrapped.EdgeAdded -= value; + } + } + + public event EdgeAction EdgeRemoved + { + add + { + Wrapped.EdgeRemoved += value; + } + + remove + { + Wrapped.EdgeRemoved -= value; + } + } + + /// + /// Gets Parent graph. + /// + public ClusteredBidirectionalGraph? Parent { get; } + + /// + /// Gets or sets the edge capacity. + /// + public int EdgeCapacity + { + get => Wrapped.EdgeCapacity; + set => Wrapped.EdgeCapacity = value; + } + + /// + /// Gets the type of vertices. + /// + public Type VertexType => typeof(TVertex); + + /// + /// Gets the type of edges. + /// + public Type EdgeType => typeof(TEdge); + + /// + public bool IsDirected => Wrapped.IsDirected; + + /// + public bool AllowParallelEdges => Wrapped.AllowParallelEdges; + + /// + public bool Collapsed { get; set; } + + /// + public IEnumerable Clusters => _clusters.AsEnumerable(); + + /// + public int ClustersCount => _clusters.Count; + + /// + public bool IsVerticesEmpty => Wrapped.IsVerticesEmpty; + + /// + public int VertexCount => Wrapped.VertexCount; + + /// + public virtual IEnumerable Vertices => Wrapped.Vertices; + + /// + public bool IsEdgesEmpty => Wrapped.IsEdgesEmpty; + + /// + public int EdgeCount => Wrapped.EdgeCount; + + /// + public virtual IEnumerable Edges => Wrapped.Edges; + + /// + /// Gets Wrapped graph. + /// + protected BidirectionalGraph Wrapped { get; } + + /// + public bool ContainsEdge(TEdge edge) + { + return Wrapped.ContainsEdge(edge); + } + + /// + /// Adds a new cluster. + /// + /// The added cluster. + public ClusteredBidirectionalGraph AddCluster() + { + var cluster = new ClusteredBidirectionalGraph(this); + _clusters.Add(cluster); + return cluster; + } + + /// + IClusteredGraph IClusteredGraph.AddCluster() + { + return AddCluster(); + } + + /// + public void RemoveCluster(IClusteredGraph graph) + { + if (graph is null) + { + throw new ArgumentNullException(nameof(graph)); + } + + _clusters.Remove(graph); + } + + /// + public bool ContainsVertex(TVertex vertex) + { + return Wrapped.ContainsVertex(vertex); + } + + /// + public bool ContainsEdge(TVertex source, TVertex target) + { + return Wrapped.ContainsEdge(source, target); + } + + /// + public bool TryGetEdge(TVertex source, TVertex target, out TEdge edge) + { + return Wrapped.TryGetEdge(source, target, out edge); + } + + /// + public virtual bool TryGetEdges(TVertex source, TVertex target, out IEnumerable edges) + { + return Wrapped.TryGetEdges(source, target, out edges); + } + + /// + public bool IsOutEdgesEmpty(TVertex vertex) + { + return Wrapped.IsOutEdgesEmpty(vertex); + } + + /// + public int OutDegree(TVertex vertex) + { + return Wrapped.OutDegree(vertex); + } + + /// + public virtual IEnumerable OutEdges(TVertex vertex) + { + return Wrapped.OutEdges(vertex); + } + + /// + public virtual bool TryGetOutEdges(TVertex vertex, out IEnumerable edges) + { + return Wrapped.TryGetOutEdges(vertex, out edges); + } + + /// + public TEdge OutEdge(TVertex vertex, int index) + { + return Wrapped.OutEdge(vertex, index); + } + + /// + /// Adds a vertex to this graph. + /// + /// Vertex to add. + /// True if the vertex was added, false otherwise. + /// is . + public virtual bool AddVertex(TVertex vertex) + { + if (vertex == null) + { + throw new ArgumentNullException(nameof(vertex)); + } + + if (!(Parent is null || Parent.ContainsVertex(vertex))) + { + Parent.AddVertex(vertex); + return Wrapped.AddVertex(vertex); + } + + return Wrapped.AddVertex(vertex); + } + + /// + /// Adds given vertices to this graph. + /// + /// Vertices to add. + /// The number of vertex added. + /// + /// is or at least one of them is . + /// + public virtual int AddVertexRange([NotNull] IEnumerable vertices) + { + if (vertices is null) + { + throw new ArgumentNullException(nameof(vertices)); + } + + TVertex[] verticesArray = vertices.ToArray(); + if (verticesArray.Any(v => v == null)) + { + throw new ArgumentNullException(nameof(vertices), "At least one vertex is null."); + } + + return verticesArray.Count(AddVertex); + } + + /// + /// Removes the given vertex from this graph. + /// + /// Vertex to remove. + /// True if the vertex was removed, false otherwise. + /// is . + public virtual bool RemoveVertex(TVertex vertex) + { + if (vertex == null) + { + throw new ArgumentNullException(nameof(vertex)); + } + + if (!Wrapped.ContainsVertex(vertex)) + { + return false; + } + + RemoveVertexInternal(vertex); + + return true; + } + + /// + /// Removes all vertices matching the given . + /// + /// Predicate to check on each vertex. + /// The number of vertex removed. + /// is . + public int RemoveVertexIf([NotNull] VertexPredicate predicate) + { + if (predicate is null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + var verticesToRemove = new VertexList(); + verticesToRemove.AddRange(Vertices.Where(vertex => predicate(vertex))); + + foreach (TVertex vertex in verticesToRemove) + { + RemoveVertexInternal(vertex); + } + + return verticesToRemove.Count; + } + + /// + /// Adds and its vertices to this graph. + /// + /// The edge to add. + /// True if the edge was added, false otherwise. + /// is . + public virtual bool AddVerticesAndEdge(TEdge edge) + { + if (edge == null) + { + throw new ArgumentNullException(nameof(edge)); + } + + AddVertex(edge.Source); + AddVertex(edge.Target); + return AddEdge(edge); + } + + /// + /// Adds a set of edges (and it's vertices if necessary). + /// + /// Edges to add. + /// The number of edges added. + /// + /// is or at least one of them is . + /// + public int AddVerticesAndEdgeRange([NotNull] IEnumerable edges) + { + if (edges is null) + { + throw new ArgumentNullException(nameof(edges)); + } + + TEdge[] edgesArray = edges.ToArray(); + if (edgesArray.Any(e => e == null)) + { + throw new ArgumentNullException(nameof(edges), "At least one edge is null."); + } + + return edgesArray.Count(AddVerticesAndEdge); + } + + /// + /// Adds the to this graph. + /// + /// An edge. + /// True if the edge was added, false otherwise. + /// is . + public virtual bool AddEdge(TEdge edge) + { + if (edge == null) + { + throw new ArgumentNullException(nameof(edge)); + } + + if (Parent != null && !Parent.ContainsEdge(edge)) + { + Parent.AddEdge(edge); + } + + return Wrapped.AddEdge(edge); + } + + /// + /// Adds a set of edges to this graph. + /// + /// Edges to add. + /// The number of edges successfully added to this graph. + /// + /// is or at least one of them is . + /// + public int AddEdgeRange([NotNull] IEnumerable edges) + { + if (edges is null) + { + throw new ArgumentNullException(nameof(edges)); + } + + TEdge[] edgesArray = edges.ToArray(); + if (edgesArray.Any(e => e == null)) + { + throw new ArgumentNullException(nameof(edges), "At least one edge is null."); + } + + return edgesArray.Count(AddEdge); + } + + /// + /// Removes the from this graph. + /// + /// Edge to remove. + /// True if the was successfully removed, false otherwise. + /// is . + public virtual bool RemoveEdge(TEdge edge) + { + if (edge == null) + { + throw new ArgumentNullException(nameof(edge)); + } + + if (!Wrapped.ContainsEdge(edge)) + { + return false; + } + + RemoveEdgeInternal(edge); + + return true; + } + + /// + /// Removes all edges that match the given . + /// + /// Predicate to check if an edge should be removed. + /// The number of edges removed. + /// is . + public int RemoveEdgeIf([NotNull] EdgePredicate predicate) + { + if (predicate is null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + var edgesToRemove = new EdgeList(); + edgesToRemove.AddRange(Edges.Where(edge => predicate(edge))); + + foreach (TEdge edge in edgesToRemove) + { + RemoveEdgeInternal(edge); + } + + return edgesToRemove.Count; + } + + /// + /// Removes all out-edges of the + /// where the is evaluated to true. + /// + /// The vertex. + /// Predicate to remove edges. + /// The number of removed edges. + /// is . + /// is . + public int RemoveOutEdgeIf(TVertex vertex, [NotNull] EdgePredicate predicate) + { + if (vertex == null) + { + throw new ArgumentNullException(nameof(vertex)); + } + + if (predicate is null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + int edgeToRemoveCount = Wrapped.RemoveOutEdgeIf(vertex, predicate); + Parent?.RemoveOutEdgeIf(vertex, predicate); + + return edgeToRemoveCount; + } + + /// + /// Clears the out-edges of the given . + /// + /// The vertex. + /// is . + public void ClearOutEdges(TVertex vertex) + { + if (vertex == null) + { + throw new ArgumentNullException(nameof(vertex)); + } + + Wrapped.ClearOutEdges(vertex); + } + + /// + /// Clears the vertex and edges. + /// + public void Clear() + { + Wrapped.Clear(); + _clusters.Clear(); + } + + /// + public int RemoveInEdgeIf(TVertex vertex, EdgePredicate predicate) => Wrapped.RemoveInEdgeIf(vertex, predicate); + + /// + public void ClearInEdges(TVertex vertex) => Wrapped.ClearInEdges(vertex); + + /// + public void ClearEdges(TVertex vertex) => Wrapped.ClearEdges(vertex); + + /// + public void TrimEdgeExcess() => Wrapped.TrimEdgeExcess(); + + /// + public bool IsInEdgesEmpty(TVertex vertex) => Wrapped.IsInEdgesEmpty(vertex); + + /// + public int InDegree(TVertex vertex) => Wrapped.InDegree(vertex); + + /// + public IEnumerable InEdges(TVertex vertex) => Wrapped.InEdges(vertex); + + /// + public bool TryGetInEdges(TVertex vertex, out IEnumerable edges) => Wrapped.TryGetInEdges(vertex, out edges); + + /// + public TEdge InEdge(TVertex vertex, int index) => Wrapped.InEdge(vertex, index); + + /// + public int Degree(TVertex vertex) => Wrapped.Degree(vertex); + + private void RemoveChildEdge(TEdge edge) + { + Debug.Assert(edge != null, "edge can't be null"); + + foreach (ClusteredBidirectionalGraph cluster in Clusters) + { + if (cluster.ContainsEdge(edge)) + { + cluster.Wrapped.RemoveEdge(edge); + cluster.RemoveChildEdge(edge); + } + } + } + +#if SUPPORTS_AGGRESSIVE_INLINING + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#endif + private void RemoveEdgeInternal(TEdge edge) + { + Debug.Assert(edge != null, "edge can't be null"); + + RemoveChildEdge(edge); + Wrapped.RemoveEdge(edge); + Parent?.RemoveEdge(edge); + } + + /// + /// Removes the given vertex from clusters. + /// + /// Vertex to remove. + private void RemoveChildVertex(TVertex vertex) + { + Debug.Assert(vertex != null, "vertex can't be null"); + + foreach (ClusteredBidirectionalGraph cluster in Clusters) + { + if (cluster.ContainsVertex(vertex)) + { + cluster.Wrapped.RemoveVertex(vertex); + cluster.RemoveChildVertex(vertex); + } + } + } + +#if SUPPORTS_AGGRESSIVE_INLINING + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#endif + private void RemoveVertexInternal(TVertex vertex) + { + Debug.Assert(vertex != null, "vertex can't be null"); + + RemoveChildVertex(vertex); + Wrapped.RemoveVertex(vertex); + Parent?.RemoveVertex(vertex); + } +} diff --git a/src/Nncase.Graph/Graphs/CondensationGraphAlgorithm.cs b/src/Nncase.Graph/Graphs/CondensationGraphAlgorithm.cs new file mode 100644 index 000000000..d4cedcefa --- /dev/null +++ b/src/Nncase.Graph/Graphs/CondensationGraphAlgorithm.cs @@ -0,0 +1,223 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using QuikGraph; +using QuikGraph.Algorithms; + +namespace Nncase.Graphs; + +public delegate bool IsEdgeCompatibleEventHandler(CondensationGraphAlgorithm sender, IsEdgeCompatibleEventArgs args) + where TVertex : notnull + where TEdge : class, IEdge; + +public delegate bool IsGraphCompatibleEventHandler(CondensationGraphAlgorithm sender, IsGraphCompatibleEventArgs args) + where TVertex : notnull + where TEdge : class, IEdge; + +public sealed class IsGraphCompatibleEventArgs : EventArgs + where TEdge : IEdge +{ + public IsGraphCompatibleEventArgs(TEdge edge) + { + Edge = edge; + } + + public TEdge Edge { get; } +} + +public sealed class IsEdgeCompatibleEventArgs : EventArgs + where TEdge : IEdge +{ + public IsEdgeCompatibleEventArgs(TEdge edge) + { + Edge = edge; + } + + public TEdge Edge { get; } +} + +/// +/// Algorithm that condensate a graph with custom. +/// +/// Vertex type. +/// Edge type. +public sealed class CondensationGraphAlgorithm : AlgorithmBase> + where TEdge : class, IEdge + where TVertex : notnull +{ + /// + /// Initializes a new instance of the class. + /// + /// Graph to visit. + /// is . + public CondensationGraphAlgorithm(IEdgeListAndIncidenceGraph visitedGraph) + : base(visitedGraph) + { + CondensedGraph = new(true); + WrappedGraph = new(false); + ClusteredGraph = new(WrappedGraph); + VertexMap = new Dictionary>(); + EdgeMap = new(ReferenceEqualityComparer.Instance); + } + + public event IsEdgeCompatibleEventHandler? IsEdgeCompatible; + + public event IsGraphCompatibleEventHandler? IsGraphCompatible; + + public BidirectionalGraph, Edge>> CondensedGraph { get; } + + public BidirectionalGraph WrappedGraph { get; } + + public ClusteredBidirectionalGraph ClusteredGraph { get; } + + public Dictionary> VertexMap { get; } + + public Dictionary>> EdgeMap { get; } + + public MergeInfo MergeTwoVertex(TVertex source, TVertex target) + { + var sourceCluster = VertexMap[source]; + var targetCluster = VertexMap[target]; + var mergedCluster = ClusteredGraph.AddCluster(); + + // remove cluster + var (sourceRemovedVertices, sourceRemovedEdges) = RemoveCluster(sourceCluster); + var (targetRemovedVertices, targetRemovedEdges) = RemoveCluster(targetCluster); + + // add merged cluster + foreach (var removedVertex in sourceRemovedVertices.Concat(targetRemovedVertices)) + { + mergedCluster.AddVertex(removedVertex); + VertexMap.Add(removedVertex, mergedCluster); + } + + CondensedGraph.AddVertex(mergedCluster); + foreach (var removedEdge in sourceRemovedEdges.Concat(targetRemovedEdges)) + { + ReAddEdges(removedEdge); + } + + return new(sourceRemovedVertices, sourceRemovedEdges, targetRemovedVertices, targetRemovedEdges, mergedCluster); + } + + public void SplitTwoVertex(MergeInfo mergeInfo) + { + RemoveCluster(mergeInfo.MergedCluster); + var sourceCluster = ClusteredGraph.AddCluster(); + foreach (var vertex in mergeInfo.SourceRemovedVertices) + { + sourceCluster.AddVertex(vertex); + VertexMap.Add(vertex, sourceCluster); + } + + var targetCluster = ClusteredGraph.AddCluster(); + foreach (var vertex in mergeInfo.TargetRemovedVertices) + { + targetCluster.AddVertex(vertex); + VertexMap.Add(vertex, targetCluster); + } + + CondensedGraph.AddVertex(sourceCluster); + CondensedGraph.AddVertex(targetCluster); + foreach (var removedEdge in mergeInfo.SourceRemovedEdges.Concat(mergeInfo.TargetRemovedEdges)) + { + ReAddEdges(removedEdge); + } + } + + /// + protected override void InternalCompute() + { + if (VisitedGraph.VertexCount == 0) + { + return; + } + + // 1. add vertices and edges into graphs. + foreach (var vertex in VisitedGraph.Vertices) + { + var cluster = ClusteredGraph.AddCluster(); + cluster.AddVertex(vertex); + CondensedGraph.AddVertex(cluster); + VertexMap.Add(vertex, cluster); + } + + foreach (var edge in VisitedGraph.Edges) + { + var condensedEdge = new Edge>(VertexMap[edge.Source], VertexMap[edge.Target]); + CondensedGraph.AddEdge(condensedEdge); + ClusteredGraph.AddEdge(edge); + EdgeMap.Add(edge, condensedEdge); + } + + // 2. try to merge vertices. + var dfsVisitEdge = new QuikGraph.Algorithms.Search.EdgeDepthFirstSearchAlgorithm(VisitedGraph); + dfsVisitEdge.InitializeEdge += (edge) => + { + var compatible = IsEdgeCompatible?.Invoke(this, new(edge)) ?? false; + if (!compatible) + { + return; + } + + // modify the condensated graph and clustered graph. + var mergeInfo = MergeTwoVertex(edge.Source, edge.Target); + + compatible = IsGraphCompatible?.Invoke(this, new(edge)) ?? true; + + if (!compatible) + { + SplitTwoVertex(mergeInfo); + } + }; + dfsVisitEdge.Compute(); + } + + private (List RemovedVertices, List RemovedEdges) RemoveCluster(ClusteredBidirectionalGraph cluster) + { + var removedVertices = new List(); + var removedEdges = new List(); + EdgeAction edgeEvent = (e) => + { + CondensedGraph.RemoveEdge(EdgeMap[e]); + EdgeMap.Remove(e); + removedEdges.Add(e); + }; + + ClusteredGraph.EdgeRemoved += edgeEvent; + foreach (var item in cluster.Vertices) + { + ClusteredGraph.RemoveVertex(item); + VertexMap.Remove(item); + removedVertices.Add(item); + } + + ClusteredGraph.EdgeRemoved -= edgeEvent; + ClusteredGraph.RemoveCluster(cluster); + CondensedGraph.RemoveVertex(cluster); + return (removedVertices, removedEdges); + } + + private void ReAddEdges(TEdge removedEdge) + { + var condensedEdge = new Edge>(VertexMap[removedEdge.Source], VertexMap[removedEdge.Target]); + if (!ReferenceEquals(condensedEdge.Source, condensedEdge.Target)) + { + CondensedGraph.AddEdge(condensedEdge); + ClusteredGraph.AddEdge(removedEdge); + } + else + { + condensedEdge.Source.AddEdge(removedEdge); + } + + EdgeMap.Add(removedEdge, condensedEdge); + } + + public record MergeInfo(List SourceRemovedVertices, List SourceRemovedEdges, List TargetRemovedVertices, List TargetRemovedEdges, ClusteredBidirectionalGraph MergedCluster) + { + } +} diff --git a/src/Nncase.Graph/Graphs/GraphExtensions.cs b/src/Nncase.Graph/Graphs/GraphExtensions.cs new file mode 100644 index 000000000..c55eb909b --- /dev/null +++ b/src/Nncase.Graph/Graphs/GraphExtensions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using QuikGraph; + +namespace Nncase.Graphs; + +public static class GraphExtensions +{ + public static IEnumerable InEdges(this IBidirectionalGraph subGraph, IBidirectionalGraph parentGraph) + where TEdge : IEdge => subGraph.Vertices.Select(v => parentGraph.InEdges(v).Except(subGraph.InEdges(v))).SelectMany(e => e); + + public static IEnumerable OutEdges(this IBidirectionalGraph subGraph, IBidirectionalGraph parentGraph) + where TEdge : IEdge => subGraph.Vertices.Select(v => parentGraph.OutEdges(v)).SelectMany(e => e).Where(e => !subGraph.ContainsVertex(e.Target)); + + public static IEnumerable InVertices(this IBidirectionalGraph graph) + where TEdge : IEdge + => graph.Vertices.Where(v => graph.InDegree(v) == 0); + + public static IEnumerable OutVertices(this IBidirectionalGraph graph) + where TEdge : IEdge + => graph.Vertices.Where(v => graph.OutDegree(v) == 0); + + public static IEnumerable OutVertices(this IBidirectionalGraph subGraph, IBidirectionalGraph parentGraph) + where TEdge : IEdge + { + var outEdges = OutEdges(subGraph, parentGraph).ToArray(); + if (outEdges.Length == 0) + { + return OutVertices(subGraph); + } + + return outEdges.DistinctBy(e => e.Source).Select(e => e.Source); + } +} diff --git a/src/Nncase.Graph/Graphs/GraphVizExtensions.cs b/src/Nncase.Graph/Graphs/GraphVizExtensions.cs new file mode 100644 index 000000000..a4d729356 --- /dev/null +++ b/src/Nncase.Graph/Graphs/GraphVizExtensions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Nncase.Graphs; +using QuikGraph; +using QuikGraph.Graphviz; + +namespace Nncase.Graphs; + +public static class GraphVizExtensions +{ + public static void Dump(this IEdgeListGraph graph, string name, Action>? initAlgorithm = null) + where TEdge : IEdge + { + Action> empty = (_) => { }; + using (var stream = Diagnostics.DumpScope.Current.OpenFile($"{name}.dot")) + { + using (var writer = new StreamWriter(stream)) + { + writer.Write(graph.ToGraphviz(initAlgorithm ?? empty)); + } + } + } +} diff --git a/src/Nncase.Graph/Nncase.Graph.csproj b/src/Nncase.Graph/Nncase.Graph.csproj index c6b3674b1..b67b898f9 100644 --- a/src/Nncase.Graph/Nncase.Graph.csproj +++ b/src/Nncase.Graph/Nncase.Graph.csproj @@ -8,6 +8,7 @@ + diff --git a/src/Nncase.Graph/packages.lock.json b/src/Nncase.Graph/packages.lock.json index 0d6ae07d1..afd0ca584 100644 --- a/src/Nncase.Graph/packages.lock.json +++ b/src/Nncase.Graph/packages.lock.json @@ -8,6 +8,15 @@ "resolved": "2.5.0", "contentHash": "sG+mrPpXwxlXknRK5VqWUGiOmDACa9X+3ftlkQIMgOZUqxVOQSe0+HIU9PTjwqazy0pqSf8MPDXYFGl0GYWcKw==" }, + "QuikGraph.Graphviz": { + "type": "Direct", + "requested": "[2.5.0, )", + "resolved": "2.5.0", + "contentHash": "pCKpErtHGxUi72OT+2aIg1pdHdUqpqEM5J/i9rmVsEVDE4X0xb1HBPWdxv/FLZmbBjk0ZogZXZttUL3CnAPpNw==", + "dependencies": { + "QuikGraph": "2.5.0" + } + }, "StyleCop.Analyzers": { "type": "Direct", "requested": "[1.2.0-beta.556, )", diff --git a/src/Nncase.Passes/GraphPartition/ExprGraphConvertor.cs b/src/Nncase.Passes/GraphPartition/ExprGraphConvertor.cs new file mode 100644 index 000000000..e27ebdbce --- /dev/null +++ b/src/Nncase.Passes/GraphPartition/ExprGraphConvertor.cs @@ -0,0 +1,82 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive; +using Nncase.IR; +using Nncase.IR.Affine; +using QuikGraph; + +namespace Nncase.Passes.GraphPartition; + +public interface IExprVertex +{ + Expr Expr { get; } + + static abstract IExprVertex Create(Expr expr); +} + +public interface IExprEdge : IEdge + where TExprVertex : IExprVertex +{ + static abstract IExprEdge Create(TExprVertex source, TExprVertex target, int index); +} + +public sealed record ExprVertex : IExprVertex +{ + private ExprVertex(Expr expr) + { + Expr = expr; + } + + public Expr Expr { get; } + + public static IExprVertex Create(Expr expr) => new ExprVertex(expr); + + public bool Equals(ExprVertex? other) + { + return ReferenceEquals(this, other); + } + + public override int GetHashCode() => ReferenceEqualityComparer.Instance.GetHashCode(this); +} + +public sealed record ExprEdge : IExprEdge +{ + private ExprEdge(ExprVertex source, ExprVertex target, int index) + { + Source = source; + Target = target; + Index = index; + } + + public ExprVertex Source { get; } + + public ExprVertex Target { get; } + + public int Index { get; } + + public static IExprEdge Create(ExprVertex source, ExprVertex target, int index) => new ExprEdge(source, target, index); +} + +public class ExprGraphConvertor : ExprVisitor> + where TVertex : IExprVertex + where TEdge : IExprEdge +{ + protected override TVertex DefaultVisitLeaf(Expr expr, IMutableVertexAndEdgeListGraph graph) + { + var target = (TVertex)TVertex.Create(expr); + graph.AddVertex(target); + int count = 0; + foreach (var item in expr.Operands) + { + var source = Visit(item, graph); + var edge = (TEdge)TEdge.Create(source, target, count++); + graph.AddEdge(edge); + } + + return target; + } +} diff --git a/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs b/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs new file mode 100644 index 000000000..a9b15ae42 --- /dev/null +++ b/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs @@ -0,0 +1,134 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reactive; +using Nncase.Graphs; +using Nncase.IR; +using Nncase.IR.Affine; +using QuikGraph; + +namespace Nncase.Passes.GraphPartition; + +public class ExprReConstructor + where TVertex : IExprVertex + where TEdge : class, IExprEdge +{ + public ExprReConstructor(CondensationGraphAlgorithm algo) + { + Algo = algo; + ClusterMemo = new(ReferenceEqualityComparer.Instance); + } + + public CondensationGraphAlgorithm Algo { get; } + + protected Dictionary, Expr> ClusterMemo { get; } + + public Expr Construct() + { + var dfsVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm, Edge>>(Algo.CondensedGraph); + dfsVisitor.Compute(); + for (var i = 0; i < dfsVisitor.SortedVertices.Length; i++) + { + ClusterMemo.Add(dfsVisitor.SortedVertices[i], OnFinishCluster(dfsVisitor.SortedVertices[i], i)); + } + + return ClusterMemo[dfsVisitor.SortedVertices[^1]]; + } + + protected virtual IEnumerable<(Expr Pre, Expr Post)> GetClusterArgumentPairs(ClusteredBidirectionalGraph cluster) + { + var pairs = new List<(Expr Pre, Expr Post)>(); + foreach (var inEdge in cluster.InEdges(Algo.ClusteredGraph)) + { + // get in Expr + Expr postArg; + var sourceCluster = Algo.VertexMap[inEdge.Source]; + var sourceOutVertices = sourceCluster.OutVertices(Algo.ClusteredGraph).ToArray(); + if (sourceOutVertices.Length == 1) + { + postArg = ClusterMemo[sourceCluster]; + } + else + { + var sourceOutIndex = sourceOutVertices.IndexOf(inEdge.Source); + var postResult = ClusterMemo[sourceCluster]; + postArg = postResult is IR.Tuple tp ? tp.Fields[sourceOutIndex] : IR.F.Tensors.GetItem(postResult, sourceOutIndex); + } + + pairs.Add((inEdge.Source.Expr, postArg)); + } + + return pairs; + } + + protected virtual Expr OnFinishCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + return cluster.VertexCount == 1 ? OnAtomCluster(cluster, sortIndex) : OnComplexCluster(cluster, sortIndex); + } + + protected virtual Expr OnAtomCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + var pairs = GetClusterArgumentPairs(cluster); + var cloner = new ExprClusterCloner(pairs.ToDictionary(p => p.Pre, p => p.Post, new ReferenceEqualityComparer())); + return cloner.Clone(cluster.Vertices.First().Expr, default); + } + + protected virtual Expr OnComplexCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + var pairs = GetClusterArgumentPairs(cluster); + var cloner = new ExprClusterCloner(pairs.ToDictionary(p => p.Pre, p => p.Post, new ReferenceEqualityComparer())); + var outVertices = cluster.OutVertices().ToArray(); + if (outVertices.Length == 1) + { + return cloner.Clone(outVertices[0].Expr, default); + } + else + { + var fields = new List(); + foreach (var outVertex in outVertices) + { + fields.Add(cloner.Clone(outVertex.Expr, default)); + } + + return new IR.Tuple(fields.ToArray()); + } + } +} + +public class ExprClusterCloner : ExprCloner +{ + public ExprClusterCloner(Dictionary extractMemo) + { + ExtractMemo = extractMemo; + } + + public Dictionary ExtractMemo { get; } + + protected override Expr DispatchVisit(Expr expr, Unit context) + { + if (HasVisited(expr, out var result)) + { + return result; + } + + if (ExtractMemo.TryGetValue(expr, out var @param)) + { + return MarkVisited(expr, @param); + } + + return MarkVisited(expr, base.DispatchVisit(expr, context)); + } + + protected override Expr VisitLeafFunction(Function expr, Unit context) => expr; + + protected override Expr VisitLeafVar(Var expr, Unit context) => expr; + + protected override Expr VisitLeafConst(Const expr, Unit context) => expr; + + protected override Expr VisitLeafNone(None expr, Unit context) => expr; + + protected override Expr VisitLeafOp(Op expr, Unit context) => expr; +} diff --git a/src/Nncase.Passes/GraphPartition/GraphConvetor.cs b/src/Nncase.Passes/GraphPartition/GraphConvetor.cs deleted file mode 100644 index cfbe2ad8f..000000000 --- a/src/Nncase.Passes/GraphPartition/GraphConvetor.cs +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reactive; -using Nncase.IR; -using Nncase.IR.Affine; -using QuikGraph; -using QuikGraph.Graphviz; - -namespace Nncase.Passes.GraphPartition; - -public sealed class GraphContext -{ - public Graph Graph { get; set; } = new(); - - public Graph GraphSummary { get; set; } = new(); - - public SortedDictionary SubgraphMap { get; set; } = new(); - - public Dictionary OriginalVertexSubgraphMap { get; set; } = new(); - - public Dictionary SummaryVertexSubgraphMap { get; set; } = new(); - - public Dictionary> VarMap { get; set; } = new(); - - public Dictionary> OutputMap { get; set; } = new(); - - public void MergeSubgraphMap() - { - OriginalVertexSubgraphMap = Graph.Vertices.Select((v, i) => new KeyValuePair(v, i)).ToDictionary(kv => kv.Key, kv => kv.Value); - - // Create subgraph structs - var dfsAssignEdge = new QuikGraph.Algorithms.Search.EdgeDepthFirstSearchAlgorithm(Graph); - dfsAssignEdge.TreeEdge += (edge) => - { - var u = edge.Source; - var v = edge.Target; - - var u_sub_idx = OriginalVertexSubgraphMap[u]; - var v_sub_idx = OriginalVertexSubgraphMap[v]; - - if (u_sub_idx == v_sub_idx) - { - SubgraphMap[u_sub_idx].InteriorEdges.Add(edge); - } - else - { - SubgraphMap[u_sub_idx].OutputEdges.Add(edge); - SubgraphMap[v_sub_idx].InputEdges.Add(edge); - } - }; - dfsAssignEdge.Compute(); - - var dfsVisitEdge = new QuikGraph.Algorithms.Search.EdgeDepthFirstSearchAlgorithm(Graph); - dfsVisitEdge.InitializeEdge += (edge) => - { - var u = edge.Source; - var v = edge.Target; - - if (u.CompatType != v.CompatType) - { - return; - } - - if (OriginalVertexSubgraphMap[u] == OriginalVertexSubgraphMap[v]) - { - return; - } - - var tmpSubgraphMap = new SortedDictionary(SubgraphMap.Comparer); - var tmpvertexSubgraphMap = new Dictionary(OriginalVertexSubgraphMap, OriginalVertexSubgraphMap.Comparer); - foreach (var kvp in SubgraphMap) - { - tmpSubgraphMap[kvp.Key] = new Subgraph(kvp.Value.Index, new List(kvp.Value.Nodes), new List(kvp.Value.InputEdges), new List(kvp.Value.OutputEdges), new List(kvp.Value.InteriorEdges)); - } - - // var vExclusiveInputs = v_subgraph.InputEdges.Where(x => !u_subgraph.OutputEdges.Contains(x)); - // if (SubgraphMap.Values.Any(x => vExclusiveInputs.Any(y => x.OutputEdges.Contains(y) && x.InputEdges.Any(z => u_subgraph.OutputEdges.Contains(z))))) - // { - // return; - // } - var u_subgraph = tmpSubgraphMap[OriginalVertexSubgraphMap[u]]; - var v_subgraph = tmpSubgraphMap[OriginalVertexSubgraphMap[v]]; - - MergeTwoSubgraphs(v_subgraph, u_subgraph, tmpSubgraphMap, tmpvertexSubgraphMap); - - if (!HasCycles(tmpSubgraphMap, tmpvertexSubgraphMap)) - { - SubgraphMap = new SortedDictionary(tmpSubgraphMap, tmpSubgraphMap.Comparer); - OriginalVertexSubgraphMap = tmpvertexSubgraphMap; - } - }; - dfsVisitEdge.Compute(); - } - - public void SummarizeGraph(bool tiling = false) - { - MergeSubgraphMap(); - - GraphSummary = new(); - Dictionary indexMap = new(); - VarMap = SubgraphMap.ToDictionary(x => x.Key, _ => new Dictionary(ReferenceEqualityComparer.Instance)); - OutputMap = SubgraphMap.ToDictionary(x => x.Key, _ => new Dictionary(ReferenceEqualityComparer.Instance)); - - // int count = 0; - foreach (var subgraph in SubgraphMap) - { - var sg = new Graph(); - subgraph.Value.Nodes.ForEach(n => sg.AddVertex(n)); - subgraph.Value.InteriorEdges.ForEach(e => sg.AddEdge(e)); - - // sg.DumpDot(Diagnostics.DumpScope.Current.Directory + $"subgraph_{subgraph.Key}_{count++}.dot"); - var dfsVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm(sg); - dfsVisitor.Compute(); - for (var vi = 0; vi < dfsVisitor.SortedVertices.Length; vi++) - { - var vertex = dfsVisitor.SortedVertices[vi]; - if (vertex.Expr is Var v) - { - if (!VarMap[subgraph.Key].ContainsKey(v)) - { - VarMap[subgraph.Key].Add(v, new Var(v.CheckedType)); - } - } - else if (subgraph.Value.InputEdges.Any(e => e.Target == vertex)) - { - foreach (var input in subgraph.Value.InputEdges.Where(e => e.Target == vertex && e.Source.Expr is not None).Select(e => e.Source.Expr)) - { - if (input is not Const && !VarMap[subgraph.Key].ContainsKey(input)) - { - if (input.CheckedType is DistributedType d) - { - if (tiling) - { - VarMap[subgraph.Key].Add(input, new Var(d)); - } - else - { - VarMap[subgraph.Key].Add(input, new Var(d.TensorType)); - } - } - else - { - VarMap[subgraph.Key].Add(input, new Var(input.CheckedType)); - } - } - } - } - } - - var u = new Vertex(None.Default, Compat.UNKNOWN); - var outVertices = sg.Vertices.Count() == 1 ? sg.Vertices : subgraph.Value.OutputEdges.Select(e => e.Source).Distinct(); - if (!outVertices.Any()) - { - outVertices = sg.Edges.Where(e => !sg.OutEdges(e.Target).Any()).Select(e => e.Target).Distinct().ToList(); - } - - u.CompatType = sg.Vertices.First().CompatType; - - if (outVertices.Count() == 1) - { - u.Expr = outVertices.First().Expr; - if (u.CompatType == Compat.COMPATIBLE) - { - OutputMap[subgraph.Key].Add(u.Expr, -1); - } - } - else - { - u.Expr = new IR.Tuple(outVertices.Select(x => x.Expr).ToArray()); - if (u.CompatType == Compat.COMPATIBLE) - { - Enumerable.Range(0, outVertices.Count()).ToList().ForEach(i => OutputMap[subgraph.Key].Add(outVertices.ToList()[i].Expr, i)); - } - } - - SummaryVertexSubgraphMap.Add(u, subgraph.Key); - - indexMap.Add(subgraph.Key, u); - GraphSummary.AddVertex(u); - } - - Dictionary edgeMap = new(); - foreach (var subgraph in SubgraphMap) - { - foreach (var edge in subgraph.Value.OutputEdges) - { - var u = indexMap[OriginalVertexSubgraphMap[edge.Source]]; - var v = indexMap[OriginalVertexSubgraphMap[edge.Target]]; - - var newEdge = new Edge(edge.EdgeType, u, v); - GraphSummary.AddEdge(newEdge); - if (edgeMap.ContainsKey(newEdge)) - { - // System.Console.WriteLine("[ERROR] " + edge + " already mapped!"); - } - else - { - edgeMap.Add(newEdge, edge); - } - } - } - } - - private void MergeTwoSubgraphs(Subgraph target, Subgraph source, SortedDictionary subgraphMap, Dictionary vertexSubgraphMap) - { - source.Nodes.ForEach(x => vertexSubgraphMap[x] = target.Index); - - target.Nodes.AddRange(source.Nodes); - - var mergedEdges = source.OutputEdges.Where(s => target.InputEdges.Contains(s)).ToList(); - target.InteriorEdges.AddRange(mergedEdges); - target.InteriorEdges.AddRange(source.InteriorEdges); - - mergedEdges.ForEach(x => target.InputEdges.Remove(x)); - target.InputEdges.AddRange(source.InputEdges); - - source.OutputEdges.ForEach(x => - { - if (!mergedEdges.Contains(x)) - { - target.OutputEdges.Add(x); - } - }); - - subgraphMap.Remove(source.Index); - } - - private bool HasCycles(SortedDictionary subgraphMap, Dictionary vertexSubgraphMap) - { - var graphSummary = new Graph(); - Dictionary indexMap = new(); - foreach (var subgraph in subgraphMap) - { - var u = new Vertex(new Var(), subgraph.Value.Nodes[0].CompatType); - indexMap.Add(subgraph.Key, u); - graphSummary.AddVertex(u); - } - - foreach (var subgraph in subgraphMap) - { - foreach (var edge in subgraph.Value.OutputEdges) - { - var u = indexMap[vertexSubgraphMap[edge.Source]]; - var v = indexMap[vertexSubgraphMap[edge.Target]]; - - var newEdge = new Edge(edge.EdgeType, u, v); - graphSummary.AddEdge(newEdge); - } - } - - List cycles = new(); - var dfs = new QuikGraph.Algorithms.Search.EdgeDepthFirstSearchAlgorithm(graphSummary); - dfs.BackEdge += (edge) => - { - var u = edge.Source; - var v = edge.Target; - cycles.Add(edge); - }; - dfs.Compute(); - - return cycles.Count > 0; - } -} - -public sealed class GraphConvertor : ExprVisitor -{ - private int _nodeCount; - - public GraphConvertor(Func predicate) - { - Predicate = predicate; - } - - public Func Predicate { get; } - - protected override Unit VisitGrid(Grid expr, GraphContext context) - { - foreach (var operand in expr.Reads) - { - Visit(operand, context); - } - - return VisitLeafGrid(expr, context); - } - - protected override Unit VisitLeafGrid(Grid expr, GraphContext context) - { - Vertex target; - if (Predicate(expr)) - { - target = new Vertex(expr, Compat.COMPATIBLE); - } - else - { - target = new Vertex(expr, Compat.INCOMPATIBLE); - } - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - foreach (var operand in expr.Reads) - { - if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, operand))) - { - var source = context.Graph.Vertices.First(v => ReferenceEquals(v.Expr, operand)); - switch (source.CompatType, target.CompatType) - { - case (Compat.COMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.C2I, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.COMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2C, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2I, source, target)); - break; - default: - context.Graph.AddEdge(new Edge(EdgeTypes.C2C, source, target)); - break; - } - } - } - - return default; - } - - protected override Unit VisitLeafVar(Var expr, GraphContext context) - { - Vertex target; - - target = new Vertex(expr, Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - - return default; - } - - protected override Unit VisitLeafConst(Const expr, GraphContext context) - { - Vertex target; - target = new Vertex(expr, expr.CheckedType is DistributedType ? Compat.COMPATIBLE : Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - - return default; - } - - protected override Unit VisitLeafNone(None expr, GraphContext context) - { - Vertex target; - target = new Vertex(expr, Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - - return default; - } - - protected override Unit VisitLeafCall(Call expr, GraphContext context) - { - Vertex target; - if (Predicate(expr)) - { - target = new Vertex(expr, Compat.COMPATIBLE); - } - else - { - target = new Vertex(expr, Compat.INCOMPATIBLE); - } - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - foreach (var operand in expr.Arguments) - { - if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, operand))) - { - var source = context.Graph.Vertices.First(v => ReferenceEquals(v.Expr, operand)); - switch (source.CompatType, target.CompatType) - { - case (Compat.COMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.C2I, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.COMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2C, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2I, source, target)); - break; - default: - context.Graph.AddEdge(new Edge(EdgeTypes.C2C, source, target)); - break; - } - } - } - - return default; - } - - protected override Unit VisitLeafTuple(IR.Tuple expr, GraphContext context) - { - Vertex target; - var compatType = context.Graph.Vertices.First(v => ReferenceEquals(v.Expr, expr.Fields[0])).CompatType; - if (!Predicate(expr)) - { - compatType = Compat.INCOMPATIBLE; - } - - target = new Vertex(expr, compatType); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - foreach (var field in expr.Fields) - { - if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, field))) - { - var source = context.Graph.Vertices.First(v => ReferenceEquals(v.Expr, field)); - switch (source.CompatType, target.CompatType) - { - case (Compat.COMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.C2I, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.COMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2C, source, target)); - break; - case (Compat.INCOMPATIBLE, Compat.INCOMPATIBLE): - context.Graph.AddEdge(new Edge(EdgeTypes.I2I, source, target)); - break; - default: - context.Graph.AddEdge(new Edge(EdgeTypes.C2C, source, target)); - break; - } - } - } - - return default; - } - - protected override Unit DefaultVisitLeaf(Expr expr, GraphContext context) - { - return default; - } -} diff --git a/src/Nncase.Passes/GraphPartition/GraphPartitionTypes.cs b/src/Nncase.Passes/GraphPartition/GraphPartitionTypes.cs deleted file mode 100644 index dfb949cc5..000000000 --- a/src/Nncase.Passes/GraphPartition/GraphPartitionTypes.cs +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.IO; -using Nncase.IR; -using QuikGraph; -using QuikGraph.Graphviz; - -namespace Nncase.Passes.GraphPartition; - -public enum Compat -{ - UNKNOWN, - COMPATIBLE, - INCOMPATIBLE, -} - -public enum EdgeTypes -{ - UNKNOWN, - C2C, - I2I, - C2I, - I2C, -} - -public sealed record Vertex -{ - public Vertex(Expr expr, Compat compatType) - { - Expr = expr; - CompatType = compatType; - } - - public Expr Expr { get; set; } - - public Compat CompatType { get; set; } - - public override string ToString() => Expr.ToString(); - - public QuikGraph.Graphviz.Dot.GraphvizColor Color() => CompatType switch - { - Compat.INCOMPATIBLE => QuikGraph.Graphviz.Dot.GraphvizColor.Coral, - Compat.COMPATIBLE => QuikGraph.Graphviz.Dot.GraphvizColor.Olive, - _ => QuikGraph.Graphviz.Dot.GraphvizColor.Cornsilk, - }; - - public bool Equals(Vertex? other) - { - if (other is null) - { - return false; - } - - return ReferenceEquals(Expr, other.Expr) && EqualityComparer.Default.Equals(CompatType, other.CompatType); - } - - public override int GetHashCode() - { - return HashCode.Combine(ReferenceEqualityComparer.Instance.GetHashCode(Expr), CompatType.GetHashCode()); - } -} - -public sealed record Edge : IEdge -{ - public Edge(EdgeTypes edgeType, Vertex source, Vertex target) - { - EdgeType = edgeType; - Source = source; - Target = target; - } - - public EdgeTypes EdgeType { get; set; } - - public Vertex Source { get; set; } - - public Vertex Target { get; set; } - - public bool Equals(Edge? other) - { - if (other is null) - { - return false; - } - - return ReferenceEquals(Source, other.Source) && - ReferenceEquals(Target, other.Target) && - EqualityComparer.Default.Equals(EdgeType, other.EdgeType); - } - - public override int GetHashCode() - { - return HashCode.Combine(ReferenceEqualityComparer.Instance.GetHashCode(Source), ReferenceEqualityComparer.Instance.GetHashCode(Target), EdgeType.GetHashCode()); - } -} - -public sealed class Graph : AdjacencyGraph -{ - public void DumpDot(string fullPathName) - { - using (var writer = new StreamWriter(fullPathName)) - { - var a = this.ToGraphviz(algorithm => - { - algorithm.FormatVertex += (_, args) => args.VertexFormat.Label = args.Vertex.ToString(); - algorithm.FormatVertex += (_, args) => args.VertexFormat.Style = QuikGraph.Graphviz.Dot.GraphvizVertexStyle.Filled; - algorithm.FormatVertex += (_, args) => args.VertexFormat.FillColor = args.Vertex.Color(); - }); - writer.Write(a); - } - } -} - -public sealed record Subgraph(int Index, List Nodes, List InputEdges, List OutputEdges, List InteriorEdges); diff --git a/src/Nncase.Passes/packages.lock.json b/src/Nncase.Passes/packages.lock.json index a11315041..1d1765591 100644 --- a/src/Nncase.Passes/packages.lock.json +++ b/src/Nncase.Passes/packages.lock.json @@ -175,7 +175,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "CommunityToolkit.HighPerformance": { diff --git a/src/Nncase.Quantization/packages.lock.json b/src/Nncase.Quantization/packages.lock.json index 5d17b5a1c..711114598 100644 --- a/src/Nncase.Quantization/packages.lock.json +++ b/src/Nncase.Quantization/packages.lock.json @@ -195,7 +195,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.passes": { diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs index 2f9a2f563..0cba90c4e 100644 --- a/src/Nncase.Schedule/Schedule/GraphTiler.cs +++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs @@ -14,130 +14,11 @@ namespace Nncase.Schedule; -public static class GraphTiler +public class GraphTiler { - public static Expr MCTSTiling(Expr preExpr, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - var topLevel = targetOptions.MemoryCapacities.Length; - var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); - var rootState = new MCTState(rootGraph, moduleKind, prefix, "0", solveMemo, targetOptions); - var rootNode = new MCTNode(rootState); - var searcher = new MCTSearcher(); - searcher.Search(rootNode); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - rootNode.Dump("SearchTree"); - } - - var bestState = (MCTState)searcher.BestMCTNode!.State; - var replaces = new Dictionary(); - foreach (var (oldExpr, v) in exprMemo) - { - if (bestState.Results.TryGetValue(v, out var newExpr)) - { - replaces.Add(oldExpr, newExpr); - } - } - - var cloner = new ReplacingExprCloner(replaces); - return cloner.Clone(preExpr, default); - } + public int DeviceFuncionCount { get; private set; } - public static Expr Tiling(Expr preExpr, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - var topLevel = targetOptions.MemoryCapacities.Length; - var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - rootGraph.Dump($"device_func{prefix}_original"); - } - - var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, prefix, solveMemo, targetOptions); - var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); - return cloner.Clone(preExpr, default); - } - - public static (Dictionary ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, string prefix, Dictionary solveMemo, ICpuTargetOptions targetOptions) - { - // bufferize root graph. - var bufferGraphMemo = rootGraph.Bufferize(); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - bufferGraphMemo[rootGraph].Dump($"device_func{prefix}_original_buffer"); - } - - // condense the root graph. - var condensedGraph = rootGraph.Condense(); - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) - { - using (var file = Diagnostics.DumpScope.Current.OpenFile($"device_func{prefix}_condensed.dot")) - { - using var writer = new StreamWriter(file); - writer.Write(condensedGraph.ToGraphviz(init => - { - init.FormatVertex += (_, arg) => - { - if (arg.Vertex is TieredTileGraph t) - { - arg.VertexFormat.Label = t.ToString(); - } - }; - })); - } - } - - // convert root graph as tree. - var rootTree = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo); - - var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index)); - var resultMemo = new Dictionary(); - long objectValue = 0; - foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) - { - using var subscope = new Diagnostics.DumpScope($"device_func{prefix}_{i}", Diagnostics.DumpFlags.Tiling); - var primTree = treeGraphMemo[primGraph]; - HashSet inputBids; - HashSet outputBids; - - if (!solveMemo.TryGetValue(primTree, out var memo)) - { - var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); - (inputBids, outputBids) = (result.Inputs, result.Outputs); - result.ScheduleBuffers(); - var bodyBuilder = T.Sequential(); - result.Visit(primTree, new(bodyBuilder, Array.Empty())); - var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); - var funcBuilder = T.PrimFunc($"device_func{prefix}_{i}", moduleKind, parameters).Body(bodyBuilder); - var primFunc = funcBuilder.Build(); - memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue); - solveMemo.Add(primTree, memo); - } - else - { - (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs(); - } - - objectValue += memo.ObjectValue; - var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); - resultMemo.Add(primGraph, finalCall); - - // save the output. - foreach (var outputBid in outputBids) - { - if (!argumentsMemo.TryGetValue(outputBid, out var _)) - { - foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer)) - { - argumentsMemo.Add(outEdge.Target, finalCall); - } - - argumentsMemo.Add(outputBid, finalCall); - } - } - } - - return (resultMemo, objectValue); - } + public Dictionary SolveMemo { get; } = new Dictionary(new ITreeNodeComparer()); public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary bufferGraphMemo, ICpuTargetOptions targetOptions, string moduleKind) { @@ -703,6 +584,128 @@ public static void DumpAssgin(ITreeNode tree, TreeSolverPrinter printer, Diction } } + public (Dictionary ResultMemo, long ObjectValue) SolveRootGraph(TieredTileGraph rootGraph, string moduleKind, ICpuTargetOptions targetOptions) + { + // bufferize root graph. + var bufferGraphMemo = rootGraph.Bufferize(); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + bufferGraphMemo[rootGraph].Dump($"tile_buffer_graph"); + } + + // condense the root graph. + var condensedGraph = rootGraph.Condense(); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + using (var file = Diagnostics.DumpScope.Current.OpenFile($"condensed_tile_graph.dot")) + { + using var writer = new StreamWriter(file); + writer.Write(condensedGraph.ToGraphviz(init => + { + init.FormatVertex += (_, arg) => + { + if (arg.Vertex is TieredTileGraph t) + { + arg.VertexFormat.Label = t.ToString(); + } + }; + })); + } + } + + // convert root graph as tree. + var rootTree = TileNode.FromTileGraph(rootGraph, out var treeGraphMemo); + + var argumentsMemo = bufferGraphMemo[rootGraph].GetInputsOutputs().Inputs.ToDictionary(k => k, k => k.Node.Grid.GetArgument(k.Index)); + var resultMemo = new Dictionary(); + long objectValue = 0; + foreach (var (primGraph, i) in condensedGraph.TopologicalSort().Select((s, i) => (s, i))) + { + using var subSubScope = new Diagnostics.DumpScope($"device_func_{DeviceFuncionCount}", Diagnostics.DumpFlags.Tiling); + var primTree = treeGraphMemo[primGraph]; + HashSet inputBids; + HashSet outputBids; + + if (!SolveMemo.TryGetValue(primTree, out var memo)) + { + var result = SolvePrimGraph(primTree, bufferGraphMemo, targetOptions, moduleKind); + (inputBids, outputBids) = (result.Inputs, result.Outputs); + result.ScheduleBuffers(); + var bodyBuilder = T.Sequential(); + result.Visit(primTree, new(bodyBuilder, Array.Empty())); + var parameters = inputBids.Concat(outputBids).Select(k => result.PrimBufferMemo[k]).ToArray(); + var funcBuilder = T.PrimFunc($"device_func_{DeviceFuncionCount++}", moduleKind, parameters).Body(bodyBuilder); + var primFunc = funcBuilder.Build(); + memo = new(new PrimFunctionWrapper(primFunc, inputBids.Count, inputBids.Concat(outputBids).Select(bid => bid.Node.Grid.GetArgument(bid.Index).CheckedType).ToArray()), result.ObjectiveValue); + SolveMemo.Add(primTree, memo); + } + else + { + (inputBids, outputBids) = bufferGraphMemo[primGraph].GetInputsOutputs(); + } + + objectValue += memo.ObjectValue; + var finalCall = new Call(memo.Func, inputBids.Select(bid => argumentsMemo[bid]).ToArray()); + resultMemo.Add(primGraph, finalCall); + + // save the output. + foreach (var outputBid in outputBids) + { + if (!argumentsMemo.TryGetValue(outputBid, out var _)) + { + foreach (var outEdge in bufferGraphMemo[rootGraph].OutEdges(outputBid).Where(e => e.Tag is BufferEdgeKind.Outer)) + { + argumentsMemo.Add(outEdge.Target, finalCall); + } + + argumentsMemo.Add(outputBid, finalCall); + } + } + } + + return (resultMemo, objectValue); + } + + public Expr Tile(Expr preExpr, string moduleKind, ICpuTargetOptions targetOptions) + { +#if true + var topLevel = targetOptions.MemoryCapacities.Length; + var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + rootGraph.Dump($"tile_graph"); + } + + var (resultMemo, _) = SolveRootGraph(rootGraph, moduleKind, targetOptions); + var cloner = new ReplacingExprCloner(exprMemo.ToDictionary(kv => (Expr)kv.Key, kv => resultMemo[kv.Value])); + return cloner.Clone(preExpr, default); +#else + var topLevel = targetOptions.MemoryCapacities.Length; + var rootGraph = GraphBuilder.Build(preExpr, topLevel, out var exprMemo); + var rootState = new MCTState(rootGraph, moduleKind, "0", SolveMemo, targetOptions); + var rootNode = new MCTNode(rootState); + var searcher = new MCTSearcher(); + searcher.Search(rootNode); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Tiling)) + { + rootNode.Dump("SearchTree"); + } + + var bestState = (MCTState)searcher.BestMCTNode!.State; + var replaces = new Dictionary(); + foreach (var (oldExpr, v) in exprMemo) + { + if (bestState.Results.TryGetValue(v, out var newExpr)) + { + replaces.Add(oldExpr, newExpr); + } + } + + var cloner = new ReplacingExprCloner(replaces); + return cloner.Clone(preExpr, default); +#endif + } + private static void DumpGantt(Dictionary nodeBufferSizes, Dictionary> nodeBufferLiveness, TileNode primTree, int storeLevel) { string GetStartStr(string name, int start) => $"[{name}] starts D+{start}"; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs index 856e743fa..57d74bd66 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/GraphMCTS.cs @@ -19,32 +19,29 @@ public sealed class MCTState : IEnvironmentState { private readonly string _path = string.Empty; + private readonly GraphTiler _graphTiler; + private readonly List _mergePoints = new(); private readonly List _legalIndex = new(); - private readonly Dictionary _solveMemo; - private readonly string _moduleKind; - private readonly string _prefix; - private readonly ICpuTargetOptions _targetOptions; private readonly TieredTileGraph _graph; private int _permformCount; - public MCTState(TieredTileGraph graph, string moduleKind, string prefix, string searchPath, Dictionary solveMemo, ICpuTargetOptions targetOptions) + public MCTState(TieredTileGraph graph, string moduleKind, string searchPath, GraphTiler graphTiler, ICpuTargetOptions targetOptions) { _graph = graph; _moduleKind = moduleKind; - _prefix = prefix; - _solveMemo = solveMemo; _targetOptions = targetOptions; _mergePoints.AddRange(graph.GetMergePoints()); _legalIndex.AddRange(Enumerable.Range(0, _mergePoints.Count)); _path = searchPath; + _graphTiler = graphTiler; Results = new(new LeafTileGraphComparer()); } @@ -70,7 +67,7 @@ public int LegalActions() var newGraph = _graph.Clone(); if (newGraph.Merge(mergePoint)) { - return new MCTState(newGraph, _moduleKind, _prefix, $"{_path}.{_permformCount}", _solveMemo, _targetOptions); + return new MCTState(newGraph, _moduleKind, $"{_path}.{_permformCount}", _graphTiler, _targetOptions); } return null; @@ -83,7 +80,7 @@ public double RollOut() using var scope = new Diagnostics.DumpScope($"RollOut{_path}"); try { - var res = GraphTiler.SolveRootGraph(_graph, _moduleKind, _prefix, _solveMemo, _targetOptions); + var res = _graphTiler.SolveRootGraph(_graph, _moduleKind, _targetOptions); ObjectValue = res.ObjectValue; foreach (var item in res.ResultMemo) { diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index 391b80eb1..32cab4bb1 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -7,10 +7,13 @@ using System.Text; using System.Threading.Tasks; using NetFabric.Hyperlinq; +using Nncase.Graphs; using Nncase.IR; using Nncase.IR.Affine; using Nncase.Passes.GraphPartition; using Nncase.Schedule; +using QuikGraph; +using QuikGraph.Algorithms; namespace Nncase.Passes.Transforms; @@ -31,98 +34,208 @@ public AutoTilePass(string moduleKind, CompileOptions compileOptions) protected override Task RunCoreAsync(IRModule input, RunPassContext context) { - var memo = new Dictionary(); + var tiler = new GraphTiler(); var funcNums = input.Functions.Count; for (int i = 0; i < funcNums; i++) { - var post = Rewrite(input.Functions[i], i, memo); + var pre = input.Functions[i]; + using var scope = new Diagnostics.DumpScope(pre.Name); + var post = Rewrite(pre, tiler); input.Replace(i, post); } return Task.FromResult(input); } - private BaseFunction Rewrite(BaseFunction pre, int funcNumber, Dictionary memo) + private BaseFunction Rewrite(BaseFunction pre, GraphTiler tiler) { if (!(pre is IR.Fusion fusion && fusion.ModuleKind == ModuleKind)) { return pre; } - // Function post; - var ctx = new GraphContext(); - var convertor = new GraphConvertor(x => x switch + var funcName = pre.Name; + + // 1. convert to quikgraph + var graph = new BidirectionalGraph(false); { - Grid => true, - IR.Tuple tp => tp.Fields.AsValueEnumerable().All(f => f is Grid), - _ => false, - }); - convertor.Visit(fusion.Body, ctx); + var convertor = new AutoTileExprGraphConvertor(); + convertor.Visit(fusion.Body, graph); + } - ctx.SummarizeGraph(true); + // 2. perform condensation + var condenseAlgo = new CondensationGraphAlgorithm(graph); + condenseAlgo.IsEdgeCompatible += (algo, arg) => + { + return (arg.Edge.Source.Expr, arg.Edge.Target.Expr) switch + { + (Grid, Grid) => true, + (Grid, IR.Tuple tp) => tp.Fields.AsValueEnumerable().All(x => x is Grid), + _ => false, + }; + }; - var dfsVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm(ctx.GraphSummary); - dfsVisitor.Compute(); - var exprMemo = new Dictionary(ReferenceEqualityComparer.Instance); - for (var subFuncNumber = 0; subFuncNumber < dfsVisitor.SortedVertices.Length; subFuncNumber++) + condenseAlgo.IsGraphCompatible += (algo, edge) => { - var vertex = dfsVisitor.SortedVertices[subFuncNumber]; - var subgraph = ctx.SubgraphMap[ctx.SummaryVertexSubgraphMap[vertex]]; - if (vertex.CompatType == Compat.INCOMPATIBLE) + return algo.CondensedGraph.IsDirectedAcyclicGraph(); + }; + + condenseAlgo.Compute(); + + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.Rewrite)) + { + condenseAlgo.CondensedGraph.Dump($"Condensed", init => { }); + condenseAlgo.ClusteredGraph.Dump($"Cluster", algo => { - var sg = new Graph(); - subgraph.Nodes.ForEach(n => sg.AddVertex(n)); - subgraph.InteriorEdges.ForEach(e => sg.AddEdge(e)); - - // sg.DumpDot(DumpScope.Current.Directory + $"_Incompatible_{subgraph.Index}_{vi}.dot"); - var sgVisitor = new QuikGraph.Algorithms.TopologicalSort.SourceFirstTopologicalSortAlgorithm(sg); - sgVisitor.Compute(); - foreach (var v in sgVisitor.SortedVertices) + algo.FormatVertex += (s, arg) => { - var expr = v.Expr switch - { - Call c => c.With(arguments: c.Arguments.AsValueEnumerable().Select(arg => exprMemo[arg]).ToArray()), - IR.Tuple t => t.With(fields: t.Fields.AsValueEnumerable().Select(arg => exprMemo[arg]).ToArray()), - _ => v.Expr, - }; - exprMemo.Add(v.Expr, expr); - } - } - else + arg.VertexFormat.Label = $"{arg.Vertex.Expr.GetType().Name}"; + }; + }); + } + + // 3. reconstruction + var constructor = new AutoTileReConstructor(tiler, ModuleKind, CompileOptions, condenseAlgo); + var post = constructor.Construct(); + return fusion.With(fusion.Name, fusion.ModuleKind, post, fusion.Parameters.ToArray()); + } +} + +internal sealed class AutoTileExprGraphConvertor : ExprGraphConvertor +{ + protected override ExprVertex VisitGrid(Grid expr, IMutableVertexAndEdgeListGraph context) + { + foreach (var read in expr.Reads) + { + Visit(read, context); + } + + return VisitLeafGrid(expr, context); + } + + protected override ExprVertex VisitLeafGrid(Grid expr, IMutableVertexAndEdgeListGraph graph) + { + var target = (ExprVertex)ExprVertex.Create(expr); + graph.AddVertex(target); + int count = 0; + foreach (var item in expr.Reads) + { + var source = Visit(item, graph); + var edge = (ExprEdge)ExprEdge.Create(source, target, count++); + graph.AddEdge(edge); + } + + return target; + } +} + +internal sealed class AutoTileReConstructor : ExprReConstructor +{ + public AutoTileReConstructor(GraphTiler tiler, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm algo) + : base(algo) + { + Tiler = tiler; + ModuleKind = moduleKind; + CompileOptions = compileOptions; + } + + public GraphTiler Tiler { get; } + + public string ModuleKind { get; } + + public CompileOptions CompileOptions { get; } + + protected override Expr OnAtomCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling); + var pairs = GetClusterArgumentPairs(cluster); + var vertex = cluster.Vertices.First(); + var expr = vertex.Expr; + if (expr is Grid) + { + var extractDict = new Dictionary(ReferenceEqualityComparer.Instance); + var argumentDict = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var (pre, post) in pairs) { - var si = ctx.SummaryVertexSubgraphMap[vertex]; - var cloner = new ReplacingExprCloner(ctx.VarMap[si].ToDictionary(kv => kv.Key, kv => (Expr)kv.Value)); - var clonedCall = cloner.Clone(vertex.Expr, default); // replaces some exprs that are in the subgraph with var, avoid tiling the grids out of the subgraph. - using var scope = new Diagnostics.DumpScope($"tiling_func{funcNumber}_subfunc{subFuncNumber}"); -#if false - var tiledCall = GraphTiler.MCTSTiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions); -#else - var tiledCall = GraphTiler.Tiling(clonedCall, ModuleKind, $"{funcNumber}_{subFuncNumber}", memo, (ICpuTargetOptions)CompileOptions.TargetOptions); -#endif - - var varMap = ctx.VarMap[si].ToDictionary(kv => (Expr)kv.Value, kv => exprMemo[kv.Key]); - var substitutor = new Mutators.Substitutor(e => + if (pre is Const) { - if (varMap.TryGetValue(e, out var arg)) - { - return arg; - } - - return null; - }); + continue; + } - var cleanedCall = substitutor.Rewrite(tiledCall, default); - if (ctx.OutputMap[subgraph.Index].Count > 1) + var @var = new Var(pre.CheckedType); + var added = extractDict.TryAdd(pre, @var); + if (added) { - ctx.OutputMap[subgraph.Index].ToList().ForEach(e => exprMemo.Add(e.Key, IR.F.Tensors.GetItem(cleanedCall, e.Value))); + argumentDict.Add(@var, post); } - else + } + + var cloner = new ExprClusterCloner(extractDict); + Expr cloned = cloner.Clone(expr, default); + var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions); + var substitutor = new Mutators.Substitutor(e => + { + if (e is Var v && argumentDict.TryGetValue(v, out var arg)) { - exprMemo.Add(ctx.OutputMap[subgraph.Index].Keys.First(), cleanedCall); + return arg; } + + return null; + }); + + var substited = substitutor.Rewrite(tiled, default); + return substited; + } + else + { + var cloner = new ExprClusterCloner(pairs.ToDictionary(p => p.Pre, p => p.Post, new ReferenceEqualityComparer())); + return cloner.Clone(expr, default); + } + } + + protected override Expr OnComplexCluster(ClusteredBidirectionalGraph cluster, int sortIndex) + { + using var subscope = new Diagnostics.DumpScope($"cluster_{sortIndex}", Diagnostics.DumpFlags.Tiling); + var pairs = GetClusterArgumentPairs(cluster); + var extractDict = new Dictionary(ReferenceEqualityComparer.Instance); + var argumentDict = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var (pre, post) in pairs) + { + if (pre is Const) + { + continue; + } + + var @var = new Var(pre.CheckedType); + var added = extractDict.TryAdd(pre, @var); + if (added) + { + argumentDict.Add(@var, post); } } - return fusion.With(fusion.Name, fusion.ModuleKind, exprMemo[fusion.Body], fusion.Parameters.ToArray()); + // todo sometimes internal grid have outside dependence, so we can't fuse it when tiling. + var cloner = new ExprClusterCloner(extractDict); + var outVertices = cluster.OutVertices(Algo.ClusteredGraph).ToArray(); + var clones = new List(); + foreach (var outVertex in outVertices) + { + clones.Add(cloner.Clone(outVertex.Expr, default)); + } + + Expr cloned = clones.Count == 1 ? clones[0] : new IR.Tuple(clones.ToArray()); + var tiled = Tiler.Tile(cloned, ModuleKind, (ICpuTargetOptions)CompileOptions.TargetOptions); + var substitutor = new Mutators.Substitutor(e => + { + if (e is Var v && argumentDict.TryGetValue(v, out var arg)) + { + return arg; + } + + return null; + }); + + var substited = substitutor.Rewrite(tiled, default); + return substited; } } diff --git a/src/Nncase.Schedule/packages.lock.json b/src/Nncase.Schedule/packages.lock.json index 02e78a2f5..43f121dee 100644 --- a/src/Nncase.Schedule/packages.lock.json +++ b/src/Nncase.Schedule/packages.lock.json @@ -166,7 +166,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.passes": { diff --git a/src/Nncase.Studio/packages.lock.json b/src/Nncase.Studio/packages.lock.json index d37425bf0..e395e233a 100644 --- a/src/Nncase.Studio/packages.lock.json +++ b/src/Nncase.Studio/packages.lock.json @@ -978,7 +978,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.importer": { diff --git a/src/Nncase.Targets/packages.lock.json b/src/Nncase.Targets/packages.lock.json index c59a0f2c0..ed2893f09 100644 --- a/src/Nncase.Targets/packages.lock.json +++ b/src/Nncase.Targets/packages.lock.json @@ -174,7 +174,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.io": { diff --git a/src/Nncase.Tests.TestFixture/packages.lock.json b/src/Nncase.Tests.TestFixture/packages.lock.json index 3f365de3b..fe5aba1ab 100644 --- a/src/Nncase.Tests.TestFixture/packages.lock.json +++ b/src/Nncase.Tests.TestFixture/packages.lock.json @@ -765,7 +765,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.importer": { diff --git a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs index e58bdd896..360685cd6 100644 --- a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs +++ b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs @@ -341,7 +341,10 @@ public void TestSolveTileGraph(Func functor, IntMergePoint[] mergePoin tileGraph.Dump($"g{count}_m"); #endif - Schedule.GraphTiler.SolveRootGraph(tileGraph, Targets.CPUTarget.Kind, count.ToString(), new(), targetOptions); + var tiler = new Schedule.GraphTiler(); + using var scope = new Diagnostics.DumpScope($"{count}"); + var result = tiler.Tile(post, Nncase.Targets.CPUTarget.Kind, (ICpuTargetOptions)CompileOptions.TargetOptions); + action(result); } [Theory] @@ -357,12 +360,14 @@ public void TestMCTS(Func functor, int count) builder.Visit(post); var tileGraph = builder.RootGraph; - var memo = new Dictionary(new ITreeNodeComparer()); - var state = new MCTState(tileGraph, "cpu", count.ToString(), string.Empty, memo, targetOptions); + var tiler = new Schedule.GraphTiler(); + var state = new MCTState(tileGraph, "cpu", count.ToString(), tiler, targetOptions); var rootNode = new MCTNode(state); var searcher = new MCTSearcher(); searcher.Search(rootNode); +#if DEBUG rootNode.Dump("mct"); +#endif } [Theory] diff --git a/src/Nncase.Tests/Graphs/UnitTestCondensationGraphAlgorithm.cs b/src/Nncase.Tests/Graphs/UnitTestCondensationGraphAlgorithm.cs new file mode 100644 index 000000000..7be08cb7a --- /dev/null +++ b/src/Nncase.Tests/Graphs/UnitTestCondensationGraphAlgorithm.cs @@ -0,0 +1,190 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using Nncase.Graphs; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.Passes; +using Nncase.Tests.TestFixture; +using QuikGraph; +using QuikGraph.Algorithms; +using QuikGraph.Graphviz; +using Xunit; + +namespace Nncase.Tests.Graphs; + +[AutoSetupTestMethod(InitSession = true)] +public class UnitTestCondensationGraphAlgorithm : TestClassBase +{ + [Fact] + public void TestSimpleBidirectionalGraph() + { + var graph = GraphCase0(); + Dump(graph, "biGraph"); + } + + [Fact] + public void TestMerge0() + { + var graph = GraphCase0(); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.Compute(); + condseAlgo.MergeTwoVertex(0, 1); + condseAlgo.MergeTwoVertex(2, 3); + + Dump(condseAlgo.ClusteredGraph, "ClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "CondensedGraph"); + Assert.Equal(3, condseAlgo.CondensedGraph.VertexCount); + Assert.Equal(2, condseAlgo.CondensedGraph.OutDegree(condseAlgo.VertexMap[0])); + Assert.Equal(2, condseAlgo.CondensedGraph.InDegree(condseAlgo.VertexMap[2])); + Assert.Equal(condseAlgo.VertexMap[0], condseAlgo.VertexMap[1]); + Assert.Equal(condseAlgo.VertexMap[2], condseAlgo.VertexMap[3]); + } + + [Fact] + public void TestMerge1() + { + var graph = new BidirectionalGraph>(false); + graph.AddVerticesAndEdge(new(0, 1)); + graph.AddVerticesAndEdge(new(2, 1)); + graph.AddVerticesAndEdge(new(1, 3)); + graph.AddVerticesAndEdge(new(4, 3)); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.Compute(); + condseAlgo.MergeTwoVertex(0, 1); + condseAlgo.MergeTwoVertex(2, 1); + condseAlgo.MergeTwoVertex(3, 1); + + Dump(condseAlgo.ClusteredGraph, "ClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "CondensedGraph"); + Assert.Equal(2, condseAlgo.CondensedGraph.VertexCount); + Assert.Equal(4, condseAlgo.VertexMap[0].VertexCount); + Assert.Equal(1, condseAlgo.VertexMap[4].VertexCount); + } + + [Fact] + public void TestSplit0() + { + var graph = new BidirectionalGraph>(false); + graph.AddVerticesAndEdge(new(0, 1)); + graph.AddVerticesAndEdge(new(2, 1)); + graph.AddVerticesAndEdge(new(1, 3)); + graph.AddVerticesAndEdge(new(4, 3)); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.Compute(); + condseAlgo.MergeTwoVertex(0, 1); + condseAlgo.MergeTwoVertex(2, 1); + condseAlgo.MergeTwoVertex(3, 4); + var info = condseAlgo.MergeTwoVertex(1, 3); + + Dump(condseAlgo.ClusteredGraph, "MergeClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "MergeCondensedGraph"); + Assert.Equal(1, condseAlgo.CondensedGraph.VertexCount); + + condseAlgo.SplitTwoVertex(info); + Dump(condseAlgo.ClusteredGraph, "SplitClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "SplitCondensedGraph"); + Assert.Equal(2, condseAlgo.CondensedGraph.VertexCount); + Assert.Equal(3, condseAlgo.VertexMap[0].VertexCount); + Assert.Equal(2, condseAlgo.VertexMap[4].VertexCount); + } + + [Fact] + public void TestCondensation0() + { + var graph = GraphCase0(); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.IsEdgeCompatible += (s, arg) => + { + if (arg.Edge is { Source: 1, Target: 3 }) + { + return true; + } + + return false; + }; + condseAlgo.Compute(); + + Dump(condseAlgo.ClusteredGraph, "ClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "CondensedGraph"); + } + + [Fact] + public void TestCondensation1() + { + var graph = GraphCase0(); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.IsEdgeCompatible += (s, arg) => + { + if (arg.Edge is { Source: 1, Target: 3 } or { Source: 0, Target: 1 }) + { + return true; + } + + return false; + }; + condseAlgo.Compute(); + + Dump(condseAlgo.ClusteredGraph, "ClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "CondensedGraph"); + + Assert.False(condseAlgo.CondensedGraph.IsDirectedAcyclicGraph()); + } + + [Fact] + public void TestCondensation2() + { + var graph = GraphCase0(); + var condseAlgo = new CondensationGraphAlgorithm>(graph); + condseAlgo.IsEdgeCompatible += (s, arg) => + { + if (arg.Edge is { Source: 1, Target: 3 } or { Source: 0, Target: 1 }) + { + return true; + } + + return false; + }; + condseAlgo.IsGraphCompatible += (s, arg) => + { + return s.CondensedGraph.IsDirectedAcyclicGraph(); + }; + condseAlgo.Compute(); + + Dump(condseAlgo.ClusteredGraph, "ClusteredGraph"); + Dump(condseAlgo.CondensedGraph, "CondensedGraph"); + + // NOTE actually we can't merge 1,3 and 0,1 simultaneously, but according to the dfs order, finally the 0,1 will be merge. + Assert.Equal(4, condseAlgo.CondensedGraph.VertexCount); + } + + private BidirectionalGraph> GraphCase0() + { + var biGraph = new BidirectionalGraph>(false); + biGraph.AddVerticesAndEdge(new(0, 1)); + biGraph.AddVerticesAndEdge(new(0, 2)); + biGraph.AddVerticesAndEdge(new(1, 3)); + biGraph.AddVerticesAndEdge(new(2, 3)); + biGraph.AddVerticesAndEdge(new(3, 4)); + return biGraph; + } + + private void Dump(IEdgeListGraph graph, string name) + where TEdge : IEdge + { +#if DEBUG + using (var stream = Diagnostics.DumpScope.Current.OpenFile($"{name}.dot")) + { + using (var writer = new StreamWriter(stream)) + { + writer.Write(graph.ToGraphviz()); + } + } +#endif + } +} diff --git a/src/Nncase.Tests/Rewrite/Fusion/UnitTestGraphPartition.cs b/src/Nncase.Tests/Rewrite/Fusion/UnitTestGraphPartition.cs index db71d4625..adf040bad 100644 --- a/src/Nncase.Tests/Rewrite/Fusion/UnitTestGraphPartition.cs +++ b/src/Nncase.Tests/Rewrite/Fusion/UnitTestGraphPartition.cs @@ -113,9 +113,10 @@ public async Task TestLineSmaeModuleC() [Fact] public async Task TestLineDiffModuleC2I() { - var inType = new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 }); - var input = new Var("input", inType); - var main = new Function("main", IR.F.CPU.Boxing(IR.F.Math.Abs(IR.F.CPU.Boxing(input, new DistributedType(inType, new SBP[] { SBP.B }, new(new[] { 2 }, "t")))), inType), input); + var ttype = new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 }); + var input = new Var("input", ttype); + var unary = IR.F.CPU.Boxing(input, new DistributedType(ttype, new[] { SBP.B }, new(new[] { 1 }, "b"))); + var main = new Function("main", IR.F.Math.Abs(IR.F.CPU.Boxing(unary, ttype)), input); Assert.True(CompilerServices.InferenceType(main)); @@ -149,9 +150,10 @@ public async Task TestLineDiffModuleC2I() [Fact] public async Task TestLineDiffModuleI2C() { - var inType = new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 }); - var input = new Var("input", inType); - var main = new Function("main", IR.F.CPU.Boxing(IR.F.Math.Abs(IR.F.CPU.Boxing(input, new DistributedType(inType, new SBP[] { SBP.B }, new(new[] { 2 }, "t")))), inType), input); + var ttype = new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 }); + var input = new Var("input", ttype); + var unary = IR.F.CPU.Boxing(IR.F.Math.Abs(input), new DistributedType(ttype, new[] { SBP.B }, new(new[] { 1 }, "b"))); + var main = new Function("main", IR.F.CPU.Boxing(IR.F.Math.Abs(unary), ttype), input); Assert.True(CompilerServices.InferenceType(main)); @@ -177,7 +179,7 @@ public async Task TestLineDiffModuleI2C() var post_number = tv.CountCallFusion(); var post_result = CompilerServices.Evaluate(((Function)module.Entry!).Body, feed_dict); - Assert.Equal(1, pre_number); + Assert.Equal(2, pre_number); Assert.Equal(1, post_number); Assert.Equal(pre_result, post_result); } @@ -412,13 +414,14 @@ public async Task TestCircle2SameModule() [Fact] public async Task TestCircle2DiffModule() { - var input = new Var("input", new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 })); - var v_0 = IR.F.CPU.Boxing(input, new DistributedType(input.CheckedTensorType, new[] { SBP.B }, new(new[] { 1 }, "t"))); - var v_1 = IR.F.Math.Cos(IR.F.CPU.Boxing(v_0, input.CheckedType)); + var ttype = new TensorType(DataTypes.Float32, new int[] { 1, 32, 32 }); + var input = new Var("input", ttype); + var v_0 = IR.F.CPU.Boxing(input, new DistributedType(ttype, new[] { SBP.B }, new(new[] { 1 }, "t"))); + var v_1 = IR.F.Math.Cos(IR.F.CPU.Boxing(v_0, ttype)); var v_2 = IR.F.Math.Sin(v_0); - var v_3 = IR.F.Math.Add(IR.F.CPU.Boxing(v_1, new DistributedType(input.CheckedTensorType, new[] { SBP.B }, new(new[] { 1 }, "t"))), v_2); + var v_3 = IR.F.Math.Add(IR.F.CPU.Boxing(v_1, new DistributedType(ttype, new[] { SBP.B }, new(new[] { 1 }, "t"))), v_2); var v_4 = IR.F.Math.Neg(v_3); - var main = new Function("main", v_4, new[] { input }); + var main = new Function("main", IR.F.CPU.Boxing(v_4, ttype), new[] { input }); Assert.True(CompilerServices.InferenceType(main)); @@ -446,7 +449,7 @@ public async Task TestCircle2DiffModule() Assert.Equal(3, pre_number); Assert.Equal(1, post_number); - Assert.Equal(pre_result, post_result); + Assert.True(Comparator.AllEqual(pre_result, post_result)); } [Fact] diff --git a/src/Nncase.Tests/packages.lock.json b/src/Nncase.Tests/packages.lock.json index 8d434dedb..6720a05e7 100644 --- a/src/Nncase.Tests/packages.lock.json +++ b/src/Nncase.Tests/packages.lock.json @@ -862,7 +862,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.importer": {