Skip to content

Commit

Permalink
transpose unpack tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 22, 2025
1 parent 93e43e1 commit 0526e31
Show file tree
Hide file tree
Showing 22 changed files with 300 additions and 32 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/compiler-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ jobs:
<EnvironmentVariables>
<LD_LIBRARY_PATH>${{github.workspace}}/install/lib</LD_LIBRARY_PATH>
<DYLD_LIBRARY_PATH>${{github.workspace}}/install/lib</DYLD_LIBRARY_PATH>
<NNCASE_TILING_MAX_SOLUTIONS>1</NNCASE_TILING_MAX_SOLUTIONS>
</EnvironmentVariables>
</RunConfiguration>
</RunSettings>
Expand Down Expand Up @@ -248,6 +249,7 @@ jobs:
shell: bash
env:
NNCASE_COMPILER: ${{github.workspace}}/install/Nncase.Compiler.dll
NNCASE_TILING_MAX_SOLUTIONS: 1
run: |
dotnet tool install --global dotnet-coverage
dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/onnx_basic.xml pytest tests/importer/onnx_/basic/ --doctest-modules --junitxml=test_results/onnx_basic.xml
Expand Down
4 changes: 0 additions & 4 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath)
var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
"-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty;

#if DEBUG
var config = "Debug";
#else
var config = "Release";
#endif
var script = $"""
cd {sourcePath} &&
cmake -E remove_directory build &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,34 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Matmul.cshtml", new TypedKernelTemplateModel<TIR.CPU.Matmul>(matmul)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);

break;
case TIR.CPU.Pack pack:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Pack.cshtml", new TypedKernelTemplateModel<TIR.CPU.Pack>(pack)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Transpose transpose:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Transpose.cshtml", new TypedKernelTemplateModel<TIR.CPU.Transpose>(transpose)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Unpack unpack:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unpack.cshtml", new TypedKernelTemplateModel<TIR.CPU.Unpack>(unpack)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Reduce reduce:
IndentScope.Writer.IndWrite($"reduce_{reduce.ReduceOp.ToC()}<fixed_shape<{string.Join(",", reduce.Axes)}>, fixed_shape<{string.Join(",", reduce.PackedAxes)}>, fixed_shape<{string.Join(",", reduce.PadedNums)}>>({arguments[0].Name}, {arguments[1].Name});\n");
break;
default:
throw new NotSupportedException();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@model Nncase.CodeGen.CPU.TypedKernelTemplateModel<Nncase.TIR.CPU.Transpose>
@{
}
transpose<fixed_shape<@string.Join(",", Model.Target.Perm)>>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
@(Model.Indent)transpose<fixed_shape<@string.Join(",", Model.Target.Perm)>>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class ReduceEvaluator : ITypeInferencer<Reduce>
public sealed class ReduceEvaluator : ITypeInferencer<Reduce>, IKernelInfoEvaluator<Reduce>
{
public IRType Visit(ITypeInferenceContext context, Reduce target)
{
context.CheckArgumentType<TensorType>(target, Reduce.Input);
context.CheckArgumentType<TensorType>(target, Reduce.Output);
return TupleType.Void;
}

public MicroKernelInfo Visit(Reduce op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
}
}
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class TransposeEvaluator : ITypeInferencer<Transpose>
public sealed class TransposeEvaluator : ITypeInferencer<Transpose>, IKernelInfoEvaluator<Transpose>
{
public IRType Visit(ITypeInferenceContext context, Transpose target)
{
context.CheckArgumentType<TensorType>(target, Transpose.Input);
context.CheckArgumentType<TensorType>(target, Transpose.Output);
return TupleType.Void;
}

public MicroKernelInfo Visit(Transpose op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
}
}
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,36 @@
using System.Diagnostics;
using System.Linq;
using DryIoc.ImTools;
using Google.OrTools.ConstraintSolver;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;
using Nncase.Utilities;
using OrtKISharp;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class UnpackEvaluator : ITypeInferencer<Unpack>
public sealed class UnpackEvaluator : ITypeInferencer<Unpack>, IKernelInfoEvaluator<Unpack>
{
/// <inheritdoc/>
public IRType Visit(ITypeInferenceContext context, Unpack target) => TupleType.Void;

public MicroKernelInfo Visit(Unpack op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
}

private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
{
var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
}
}
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ bool CheckField(Expr f)
}

// 3. reconstruction
var constructor = new DistributedReConstructor(funcName, ModuleKind, condenseAlgo);
var constructor = new DistributedReconstructor(funcName, ModuleKind, condenseAlgo);
var post = constructor.Construct();
return post;
}
}

internal sealed class DistributedReConstructor : ExprReConstructor<ExprVertex, ExprEdge>
internal sealed class DistributedReconstructor : ExprReconstructor<ExprVertex, ExprEdge>
{
public DistributedReConstructor(string funcName, string moduleKind, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
public DistributedReconstructor(string funcName, string moduleKind, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
: base(algo)
{
FuncName = funcName;
Expand Down
68 changes: 68 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerReduce.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.CPU.Affine;

[RuleGenerator]
public partial class LowerReduce : RewriteRule<Pattern>
{
public LowerReduce(string moduleKind = CPUTarget.Kind)
{
ModuleKind = moduleKind;
}

public string ModuleKind { get; }

public override Pattern Pattern { get; } = IsCall(
"call",
IsOp<IR.CPU.PackedReduce>("op"),
IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });

private Expr? GetReplace(Expr call, IR.CPU.PackedReduce op, Expr input)
{
var inputShape = input.CheckedShape.ToValueArray();
var rank = inputShape.Length;
var domains = IR.F.Affine.Domains(rank);
var outrank = call.CheckedShape.Rank;
var results = new AffineRange[outrank];
{
var j = 0;
for (int i = 0; i < rank; i++)
{
if (op.Axes.Contains(i))
{
if (op.KeepDims == true)
{
results[j++] = new AffineRange(0, 1);
}
}
else
{
results[j++] = new AffineRange(domains[i].Offset, domains[i].Extent);
}
}
}

var affinemap = new AffineMap(domains, default, results);
var outBuffer = call.CheckedType switch
{
TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
_ => throw new ArgumentOutOfRangeException(nameof(call)),
};

return IR.F.Affine.Grid(ModuleKind)
.Domain(rank, out var _)
.Read(input, AffineMap.Identity(rank), out var intile)
.Write(outBuffer, affinemap, out var outTile)
.Body(TIR.F.CPU.Reduce(intile, outTile, op.PackedAxes.ToArray(), op.PadedNums.ToArray(), op.Axes, op.KeepDims, op.ReduceOp))
.Build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.IR.Math;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.CPU.Affine;

[RuleGenerator]
public partial class LowerTranspose : RewriteRule<Pattern>
{
public LowerTranspose(string moduleKind = CPUTarget.Kind)
{
ModuleKind = moduleKind;
}

public string ModuleKind { get; }

public override Pattern Pattern { get; } = PatternMatch.F.Tensors.IsTranspose("trans", "call", IsWildcard("input"), IsTensorConst("perm"))
with
{ TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") };

private Expr? GetReplace(Expr call, Expr input, int[] perm)
{
var inputShape = input.CheckedShape.ToValueArray();
var rank = inputShape.Length;
var domains = IR.F.Affine.Domains(rank);
var results = new AffineRange[rank];
for (int i = 0; i < rank; i++)
{
results[perm[i]] = new AffineRange(domains[i].Offset, domains[i].Extent);
}

var inputAccessMap = new AffineMap(domains, default, results);
var outBuffer = call.CheckedType switch
{
TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
_ => throw new ArgumentOutOfRangeException(nameof(call)),
};

return IR.F.Affine.Grid(ModuleKind)
.Domain(rank, out var _)
.Read(input, inputAccessMap, out var intile)
.Write(outBuffer, AffineMap.Identity(rank), out var outTile)
.Body(TIR.F.CPU.Transpose(intile, outTile, perm))
.Build();
}
}
64 changes: 64 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnpack.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.IR.Math;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.CPU.Affine;

[RuleGenerator]
public partial class LowerUnpack : RewriteRule<Pattern>
{
public LowerUnpack(string moduleKind = CPUTarget.Kind)
{
ModuleKind = moduleKind;
}

public string ModuleKind { get; }

public override Pattern Pattern { get; } = IsCall(
"call",
IsOp<IR.CPU.Unpack>("op"),
IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });

private Expr? GetReplace(Expr call, IR.CPU.Unpack op, Expr input)
{
var inputShape = input.CheckedShape.ToValueArray();
var rank = inputShape.Length;
var domains = IR.F.Affine.Domains(rank);
var results = new AffineRange[rank];

for (int axis = 0; axis < rank; axis++)
{
// e.g. f32[128,256] -> f32<4>[32,256]
if (op.Axes.IndexOf(axis) is int i && i != -1)
{
results[axis] = new AffineRange(op.Lanes[i] * domains[axis].Offset, op.Lanes[i] * domains[axis].Extent);
}
else
{
results[axis] = new AffineRange(domains[axis].Offset, domains[axis].Extent);
}
}

var affinemap = new AffineMap(domains, default, results);
var outBuffer = call.CheckedType switch
{
TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
_ => throw new ArgumentOutOfRangeException(nameof(call)),
};

return IR.F.Affine.Grid(ModuleKind)
.Domain(rank, out var _)
.Read(input, AffineMap.Identity(rank), out var intile)
.Write(outBuffer, affinemap, out var outTile)
.Body(TIR.F.CPU.Unpack(intile, outTile, op.Lanes, op.Axes))
.Build();
}
}
Loading

0 comments on commit 0526e31

Please sign in to comment.