Refactor type infer
sunnycase committed Jan 14, 2025
1 parent 1bb40e5 commit 07932ae
Showing 200 changed files with 1,659 additions and 1,309 deletions.
2 changes: 0 additions & 2 deletions Directory.Packages.props
@@ -28,8 +28,6 @@
<PackageVersion Include="Extension.Mathematics" Version="1.2.12" />
<PackageVersion Include="Fody" Version="6.8.1" />
<PackageVersion Include="GiGraph.Dot" Version="3.0.1" />
<PackageVersion Include="QuikGraph" Version="2.5.0" />
<PackageVersion Include="QuikGraph.Graphviz" Version="2.5.0" />
<PackageVersion Include="Google.Protobuf" Version="3.27.3" />
<PackageVersion Include="Grpc.Tools" Version="2.65.0" />
<PackageVersion Include="Humanizer.Core" Version="2.14.1" />
14 changes: 7 additions & 7 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/Im2col.cs
@@ -35,10 +35,10 @@ public IValue Visit(IEvaluateContext context, Im2col target)
{
var inputTensor = context.GetArgumentValueAsTensor(target, Im2col.Input);
var lanes = inputTensor.ElementType.SizeInBytes / 4;
- int batch = inputTensor.Shape[0].FixedValue;
- int inChannel = inputTensor.Shape[1].FixedValue;
- int height = inputTensor.Shape[2].FixedValue;
- int width = inputTensor.Shape[3].FixedValue;
+ int batch = (int)inputTensor.Shape[0].FixedValue;
+ int inChannel = (int)inputTensor.Shape[1].FixedValue;
+ int height = (int)inputTensor.Shape[2].FixedValue;
+ int width = (int)inputTensor.Shape[3].FixedValue;
int pad_h_before = target.Padding[0];
int pad_h_after = target.Padding[1];
int pad_w_before = target.Padding[2];
@@ -96,7 +96,7 @@ public IValue Visit(IEvaluateContext context, Im2col target)
}
}

- return Value.FromTensor(Tensor.FromBytes(inputTensor.ElementType, System.Runtime.InteropServices.MemoryMarshal.Cast<float, byte>(outputTensor).ToArray(), new[] { inChannel * kernel_h * kernel_w, batch * output_h * output_w }));
+ return Value.FromTensor(Tensor.FromBytes(inputTensor.ElementType, System.Runtime.InteropServices.MemoryMarshal.Cast<float, byte>(outputTensor).ToArray(), [inChannel * kernel_h * kernel_w, batch * output_h * output_w]));
}

private IRType Visit(DistributedType dt, Im2col target)
@@ -143,8 +143,8 @@ private IRType Visit(DistributedType dt, Im2col target)

private IRType Visit(TensorType tt, Im2col target)
{
- int height = tt.Shape[2].FixedValue;
- int width = tt.Shape[3].FixedValue;
+ int height = (int)tt.Shape[2].FixedValue;
+ int width = (int)tt.Shape[3].FixedValue;
int pad_h_before = target.Padding[0];
int pad_h_after = target.Padding[1];
int pad_w_before = target.Padding[2];
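The Im2col hunks above show the pattern repeated throughout this commit: shape dimensions (Shape[i].FixedValue) now carry 64-bit values, so sites that still need an int must narrow explicitly. A minimal, self-contained sketch of that pattern, assuming plain long dimensions (the variable names below are illustrative, not the nncase API):

    using System;

    // Hypothetical NCHW shape; after this refactor, dimension values are 64-bit.
    long[] shape = { 1, 64, 224, 224 };

    // Previously the assignment was an implicit int; now the narrowing cast is explicit.
    int height = (int)shape[2];
    int width = (int)shape[3];

    Console.WriteLine($"spatial: {height} x {width}");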
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs
@@ -35,7 +35,7 @@ public IValue Visit(IEvaluateContext context, Pack target)

var output = Cast(inputOrt.ToTensor(), context.CurrentCall.Arguments[Pack.Input.Index].CheckedDataType).Evaluate().AsTensor();

- return Value.FromTensor(Tensor.FromBytes(new VectorType(output.ElementType, target.Lanes), output.BytesBuffer.ToArray(), inputOrt.Shape.SkipLast(target.Lanes.Count).Select(i => (int)i).ToArray()));
+ return Value.FromTensor(Tensor.FromBytes(new VectorType(output.ElementType, target.Lanes), output.BytesBuffer.ToArray(), inputOrt.Shape.SkipLast(target.Lanes.Count).Select(i => i).ToArray()));
}
else
{
4 changes: 2 additions & 2 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs
@@ -95,7 +95,7 @@ public Cost Visit(ICostEvaluateContext context, PackedBinary target)
private IRType Visit(PackedBinary target, TensorType a, TensorType b)
{
var rank = System.Math.Max(a.Shape.Rank, b.Shape.Rank);
- var outShape = new int[rank];
+ var outShape = new long[rank];
var lhsOrginShape = a.Shape.ToValueArray();
var rhsOrginShape = b.Shape.ToValueArray();
for (int i = 0; i < target.LhsPackedAxes.Count; i++)
@@ -127,7 +127,7 @@ private IRType Visit(PackedBinary target, TensorType a, TensorType b)
case ( >= 0, >= 0):
switch (lhsOrginShape[aAxis], rhsOrginShape[bAxis])
{
- case (int l, int r) when l == r:
+ case (long l, long r) when l == r:
outShape[rank + i] = a.Shape[aAxis].FixedValue;
orginKinds[rank + i] = DimKind.E;
break;
@@ -41,7 +41,7 @@ public IValue Visit(IEvaluateContext context, PackedLayerNorm target)
var outputTensor = OrtKISharp.Tensor.MakeTensor(new Memory<float>(output), OrtDataType.Float, unpackedInput.Shape);
outputTensor = CPUEvaluatorUtility.RepackTensor(outputTensor, lanes, target.PackedAxes, target.PadedNums);

- return Value.FromTensor(Tensor.FromBytes(new VectorType(DataTypes.Float32, lanes), outputTensor.BytesBuffer.ToArray(), outputTensor.Shape.SkipLast(target.PackedAxes.Count).Select(i => (int)i).ToArray()));
+ return Value.FromTensor(Tensor.FromBytes(new VectorType(DataTypes.Float32, lanes), outputTensor.BytesBuffer.ToArray(), outputTensor.Shape.SkipLast(target.PackedAxes.Count).Select(i => i).ToArray()));
}

/// <inheritdoc/>
16 changes: 8 additions & 8 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs
@@ -22,20 +22,20 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target)

var outRank = context.CurrentCall.CheckedShape.Rank;
var outLanes = Array.Empty<int>();
- var outShape = Array.Empty<int>();
+ var outShape = Array.Empty<long>();
var axes = Array.Empty<int>();
var (lm, lk) = target.TransposeA ? (lhs.Rank - target.RhsPackedAxes.Count - 1, lhs.Rank - target.RhsPackedAxes.Count - 2) : (lhs.Rank - target.LhsPackedAxes.Count - 2, lhs.Rank - target.LhsPackedAxes.Count - 1);
var (rk, rn) = target.TransposeB ? (rhs.Rank - target.RhsPackedAxes.Count - 1, rhs.Rank - target.RhsPackedAxes.Count - 2) : (rhs.Rank - target.RhsPackedAxes.Count - 2, rhs.Rank - target.RhsPackedAxes.Count - 1);
if (target.LhsPackedAxes.Count == 0 && target.RhsPackedAxes.Count == 1)
{
outLanes = new[] { (int)rhs.Shape[^1] };
- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
axes = new[] { outRank - 1 };
}
else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 0)
{
outLanes = new[] { (int)lhs.Shape[^1] };
- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
axes = new[] { outRank - 2 };
}
else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 1)
@@ -51,24 +51,24 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target)
axes = new[] { outRank - 2, outRank - 1 };
}

- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
}
else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 2)
{
outLanes = new[] { (int)rhs.Shape[^1] };
- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
axes = new[] { outRank - 1 };
}
else if (target.LhsPackedAxes.Count == 2 && target.RhsPackedAxes.Count == 1)
{
outLanes = new[] { (int)lhs.Shape[^2] };
- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
axes = new[] { outRank - 2 };
}
else if (target.LhsPackedAxes.Count == 2 && target.RhsPackedAxes.Count == 2)
{
outLanes = new[] { (int)lhs.Shape[^2], (int)rhs.Shape[^1] };
- outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] };
+ outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] };
axes = new[] { outRank - 2, outRank - 1 };
}
else
@@ -79,7 +79,7 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target)
var maxRank = System.Math.Max(lhs.Shape.Length - target.LhsPackedAxes.Count, rhs.Shape.Length - target.RhsPackedAxes.Count);
outShape = Enumerable.Repeat(1L, maxRank - lhs.Shape.Length + target.LhsPackedAxes.Count).Concat(lhs.Shape.SkipLast(2 + target.LhsPackedAxes.Count)).
Zip(Enumerable.Repeat(1L, maxRank - rhs.Shape.Length + target.RhsPackedAxes.Count).Concat(rhs.Shape.SkipLast(2 + target.RhsPackedAxes.Count))).
- Select(p => (int)System.Math.Max(p.First, p.Second)).
+ Select(p => System.Math.Max(p.First, p.Second)).
Concat(outShape).ToArray();

foreach (var axis in target.LhsPackedAxes.Reverse())
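The tail of the PackedMatMul hunk above pads the shorter operand's batch dims with 1s and takes the per-axis maximum (the usual broadcast rule), now over long values. A small stand-alone sketch of just that step, assuming plain long[] batch shapes (names are illustrative, not the nncase types):

    using System;
    using System.Linq;

    // Hypothetical batch dims of lhs and rhs (everything except the trailing M/K/N dims).
    long[] lhsBatch = { 2, 1, 8 };
    long[] rhsBatch = { 8 };

    int maxRank = Math.Max(lhsBatch.Length, rhsBatch.Length);

    // Left-pad each shape with 1s, then broadcast by taking the per-axis maximum.
    long[] outBatch = Enumerable.Repeat(1L, maxRank - lhsBatch.Length).Concat(lhsBatch)
        .Zip(Enumerable.Repeat(1L, maxRank - rhsBatch.Length).Concat(rhsBatch))
        .Select(p => Math.Max(p.First, p.Second))
        .ToArray();

    Console.WriteLine(string.Join(", ", outBatch)); // 2, 1, 8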
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs
@@ -19,7 +19,7 @@ public sealed class PackedReduceEvaluator : IEvaluator<PackedReduce>, ITypeInfer
public IValue Visit(IEvaluateContext context, PackedReduce target)
{
var input = context.GetOrtArgumentValue(target, PackedReduce.Input);
- var inshape = input.Shape.SkipLast(target.PackedAxes.Count).Select(i => (int)i).ToArray();
+ var inshape = input.Shape.SkipLast(target.PackedAxes.Count).Select(i => i).ToArray();
var inlanes = input.Shape.TakeLast(target.PackedAxes.Count).Select(i => (int)i).ToArray();
var unpackedInput = CPUEvaluatorUtility.UnpackTensor(input, target.PackedAxes, target.PadedNums, out _);
var axes = target.Axes.Select(i => (long)i).ToArray();
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs
@@ -20,7 +20,7 @@ public MicroKernelInfo Visit(Binary op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
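The one-line change above (ValueRange<int> to ValueRange<long>) recurs verbatim in the Matmul, Pack, PackedBinary, Swish, and Unary visitors below. A self-contained sketch of the widened range construction; the ValueRange<T> definition here is an assumed stand-in for illustration, not the actual Nncase type:

    using System;
    using System.Linq;

    int domainLength = 4;

    // Widened element type: the upper bound is still int.MaxValue (as in the diff),
    // but the values are stored as long.
    var multipliers = Enumerable
        .Repeat(new ValueRange<long>(1, int.MaxValue), domainLength)
        .ToArray();

    Console.WriteLine(multipliers[0]); // ValueRange { Min = 1, Max = 2147483647 }

    // Assumed minimal range type; the real Nncase ValueRange<T> may differ.
    readonly record struct ValueRange<T>(T Min, T Max);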
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs
@@ -17,7 +17,7 @@ public MicroKernelInfo Visit(Matmul op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();

var (k, m, n) = (context.BufferShapes[0][^1], context.BufferShapes[2][^2], context.BufferShapes[2][^1]);
var (lpack, rpack) = PackedMatMul.GetPackKind(op.LhsPackedAxes, op.RhsPackedAxes);
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs
@@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Nncase.TIR.CPU.Pack op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
@@ -25,7 +25,7 @@ public MicroKernelInfo Visit(PackedBinary op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs
@@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Swish swish, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs
@@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Unary op, MicroKernelContext context)
{
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
- var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange<long>(1, int.MaxValue), domain.Length).ToArray();
var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs
@@ -25,7 +25,7 @@ public sealed partial class PackedReduce : Op

public IRArray<int> PadedNums { get; }

- public static (int[] OutPackAxes, int[] OutPadNums, int[] OutLanes, int[] OutShape) ComputeOutputInfo(PackedReduce target, int[] inShape, int[] inLanes)
+ public static (int[] OutPackAxes, int[] OutPadNums, int[] OutLanes, long[] OutShape) ComputeOutputInfo(PackedReduce target, long[] inShape, int[] inLanes)
{
var packedAxes = target.PackedAxes.ToList();
var padedNums = target.PadedNums.ToList();
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
@@ -31,9 +31,9 @@ protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext c
var ctx = new GraphPartition.GraphContext();
var convertor = new GraphPartition.GraphConvertor(x => x switch
{
- Call call => (call.Target is IR.CPU.Boxing || call.CheckedType is DistributedType) ? true : false,
- IR.Tuple tp => tp.Fields.ToArray().Any(f => f is Call { Target: IR.CPU.Boxing } b && b.CheckedType is TensorType) ? false : true,
- _ => throw new NotSupportedException(),
+ Call call => (call.Target is IR.CPU.Boxing || call.CheckedType is DistributedType) ? Compat.COMPATIBLE : Compat.INCOMPATIBLE,
+ IR.Tuple tp => tp.Fields.ToArray().Any(f => f is Call { Target: IR.CPU.Boxing } b && b.CheckedType is TensorType) ? Compat.INCOMPATIBLE : Compat.COMPATIBLE,
+ _ => Compat.INCOMPATIBLE,
});
convertor.Visit(pre.Body, ctx);

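The partition convertor above now returns an enum-valued Compat result instead of a bool, and unknown node kinds classify as incompatible rather than throwing. The Compat type itself is not shown in this diff; the sketch below assumes a plain two-member enum and a stand-in expression hierarchy purely to illustrate the switch shape, not the actual GraphPartition API:

    using System;

    // Assumed shape of the compatibility tag used by the convertor (illustrative only).
    enum Compat { COMPATIBLE, INCOMPATIBLE }

    // Stand-in expression hierarchy, just enough to show the switch pattern.
    abstract record Expr;
    record CallExpr(bool IsBoxingOrDistributed) : Expr;
    record TupleExpr(bool HasBoxingTensorField) : Expr;

    static class Demo
    {
        static Compat Classify(Expr x) => x switch
        {
            CallExpr call => call.IsBoxingOrDistributed ? Compat.COMPATIBLE : Compat.INCOMPATIBLE,
            TupleExpr tp => tp.HasBoxingTensorField ? Compat.INCOMPATIBLE : Compat.COMPATIBLE,
            _ => Compat.INCOMPATIBLE, // unknown nodes now classify as incompatible instead of throwing
        };

        static void Main()
        {
            Console.WriteLine(Classify(new CallExpr(true)));  // COMPATIBLE
            Console.WriteLine(Classify(new TupleExpr(true))); // INCOMPATIBLE
        }
    }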