Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Fusion Stage 1 #1290

Merged
merged 19 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

#include <nncase/ntt/runtime.h>
#include "topo_aware_runtime.h"
#include "../device.h"
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).Skip(1).SkipLast(1)){
@:uint8_t L@(i)Data[@(s)];
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){
@:uint8_t L@(i+1)Data[@(s)];
}
#include "../device.h"
#include "kernel.h"

//alignas(@(Model.Alignment)) static thread_local uint8_t local_data[@(Model.DataSize)];
Expand Down
1 change: 1 addition & 0 deletions modules/Nncase.Modules.CPU/TIR/CPU/Unary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Nncase.TIR.CPU;

[ParameterInPlace(0, 1)]
public sealed partial class Unary : CPUKernelOp
{
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
Expand Down
14 changes: 14 additions & 0 deletions src/Nncase.Core/IR/Op.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ public enum ParameterKind : int
Attribute,
}

[AttributeUsage(System.AttributeTargets.Class, Inherited = false, AllowMultiple = true)]
public sealed class ParameterInPlaceAttribute : System.Attribute
{
public ParameterInPlaceAttribute(int sourceIndex, int destIndex)
{
SourceIndex = sourceIndex;
DestIndex = destIndex;
}

public int SourceIndex { get; }

public int DestIndex { get; }
}

/// <summary>
/// Parameter information.
/// </summary>
Expand Down
9 changes: 9 additions & 0 deletions src/Nncase.Core/TIR/TIRExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ namespace Nncase.TIR;
/// </summary>
public static class TIRExtensions
{
/// <summary>
/// Get the tir op buffer allocation reuse information.
/// </summary>
/// <returns> map dest index to source index. </returns>
public static Dictionary<int, int> GetInPlaceMemo(this Op op)
{
return op.GetType().GetCustomAttributes(typeof(ParameterInPlaceAttribute), true).OfType<ParameterInPlaceAttribute>().ToDictionary(a => a.DestIndex, a => a.SourceIndex);
}

/// <summary>
/// convert IEnumerable to tir Sequential.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ protected override void InternalCompute()

var dfs = new DepthFirstSearchAlgorithm<TVertex, TEdge>(this, VisitedGraph, new Dictionary<TVertex, GraphColor>(VisitedGraph.VertexCount));
dfs.TreeEdge += TreeEdge;
dfs.ForwardOrCrossEdge += TreeEdge;
dfs.Compute();
}

Expand Down
380 changes: 237 additions & 143 deletions src/Nncase.Schedule/Schedule/GraphTiler.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public interface IEnvironmentState<TAction>
where TAction : class
{
int LegalActions();

TAction GetNextAction(int index);

IEnvironmentState<TAction>? PerformAction(TAction action);

double RollOut();
}
44 changes: 44 additions & 0 deletions src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/SearchNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public abstract class SearchNode<T>
where T : class
{
public SearchNode(IEnvironmentState<T> state)
{
Parent = null;
Children = new List<SearchNode<T>>();
VisitTimes = 0;
QualityValue = 0.0;
State = state;
}

public SearchNode(SearchNode<T> parent, IEnvironmentState<T> state)
{
Parent = parent;
Children = new List<SearchNode<T>>();
VisitTimes = 0;
QualityValue = 0.0;
State = state;
Parent.Children.Add(this);
}

public SearchNode<T>? Parent { get; }

public List<SearchNode<T>> Children { get; }

public int VisitTimes { get; set; }

public double QualityValue { get; set; }

public IEnvironmentState<T> State { get; }

public bool IsRootNode => Parent is null;

public abstract void Update(double reward);
}
47 changes: 47 additions & 0 deletions src/Nncase.Schedule/Schedule/MonteCarloTreeSearch/Searcher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Nncase.Schedule.MonteCarloTreeSearch;

public abstract class Searcher<T>
where T : class
{
public Searcher(int searchTimes = 20)
{
SearchTimes = searchTimes;
}

public int SearchTimes { get; }

public void Search(SearchNode<T> rootNode)
{
for (int i = 0; i < SearchTimes; i++)
{
if (!Selection(rootNode, out var node))
{
return;
}

var expanded = Expand(node);
if (expanded is not null)
{
BackPropagate(expanded, Simulation(expanded));
}
else
{
BackPropagate(node, double.PositiveInfinity);
}
}
}

public abstract bool Selection(SearchNode<T> node, out SearchNode<T> selected);

public abstract SearchNode<T>? Expand(SearchNode<T> node);

public abstract double Simulation(SearchNode<T> node);

public abstract void BackPropagate(SearchNode<T> node, double reward);
}
5 changes: 4 additions & 1 deletion src/Nncase.Schedule/Schedule/OrToolsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ public static long[] Value(this Assignment sol, IntExpr[] inputs)
var vec = new long[inputs.Length];
for (int i = 0; i < inputs.Length; i++)
{
vec[i] = sol.Value(inputs[i].Var());
if (inputs[i] is not null)
{
vec[i] = sol.Value(inputs[i].Var());
}
}

return vec;
Expand Down
22 changes: 18 additions & 4 deletions src/Nncase.Schedule/Schedule/TileGraph/BufferizationAlgorithm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ private void Visit(TieredTileGraph rootGraph)
{
if (!BufferGraphMemo.TryGetValue(rootGraph, out _))
{
var wrappedGraph = new AdjacencyGraph<BufferIdentity, EquatableTaggedEdge<BufferIdentity, BufferEdgeKind>>();
var wrappedGraph = new AdjacencyGraph<BufferIdentity, EquatableTaggedEdge<BufferIdentity, BufferEdgeKind>>(allowParallelEdges: false);
var rootBufferGraph = new BufferGraph(rootGraph.Level, wrappedGraph);
Visit(rootGraph, rootBufferGraph);
Visit(rootGraph, rootBufferGraph, rootGraph);
foreach (var edge in rootGraph.Edges)
{
var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length);
Expand All @@ -77,12 +77,14 @@ private void Visit(TieredTileGraph rootGraph)
}
}

private void Visit(TieredTileGraph graph, BufferGraph bufferGraph)
private HashSet<TileGrid> Visit(TieredTileGraph graph, BufferGraph bufferGraph, TieredTileGraph rootGraph)
{
var opnodes = new HashSet<TileGrid>();
if (graph.ClustersCount == 0)
{
foreach (var item in graph.Vertices)
{
opnodes.Add(item);
var outBid = new BufferIdentity(item, item.ReadAccesses.Length);
for (int i = 0; i < item.ReadAccesses.Length; i++)
{
Expand All @@ -97,10 +99,22 @@ private void Visit(TieredTileGraph graph, BufferGraph bufferGraph)
if (!BufferGraphMemo.TryGetValue(graph, out _))
{
var childBufferGraph = bufferGraph.CreateCluster<BufferGraph>(childGraph.Level, childGraph.OpId);
Visit(childGraph, childBufferGraph);
opnodes.UnionWith(Visit(childGraph, childBufferGraph, rootGraph));
BufferGraphMemo.Add(childGraph, childBufferGraph);
}
}

foreach (var edge in rootGraph.Edges)
{
if (opnodes.Contains(edge.Source) && opnodes.Contains(edge.Target))
{
var source = new BufferIdentity(edge.Source, edge.Source.ReadAccesses.Length);
var target = new BufferIdentity(edge.Target, edge.Tag);
bufferGraph.AddEdge(new(source, target, BufferEdgeKind.Outer));
}
}
}

return opnodes;
}
}
31 changes: 31 additions & 0 deletions src/Nncase.Schedule/Schedule/TileGraph/GraphExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,37 @@ public static bool Merge(this TieredTileGraph graph, MergePoint mergePoint)
return merger.Visit(graph);
}

public static List<MergePoint> GetMergePoints(this TieredTileGraph graph)
{
var mergePoints = new List<MergePoint>();
if (graph.Level != -1)
{
throw new InvalidOperationException("only can merge at top level!");
}

var children = graph.Clusters.OfType<TieredTileGraph>().ToArray();
foreach (var producer in children)
{
foreach (var comsumer in children)
{
if (ReferenceEquals(producer, comsumer))
{
continue;
}

foreach (var edge in graph.Edges)
{
if (comsumer.ContainsVertex(edge.Source) && producer.ContainsVertex(edge.Target))
{
mergePoints.Add(new(edge.Target, edge.Source, producer.Level));
}
}
}
}

return mergePoints;
}

public static void Walk(this TieredTileGraph graph, Action<ITileable> func, bool postOrder = false)
{
if (!postOrder)
Expand Down
Loading
Loading