From 07932ae9225bc40dd5fcbead84470d4d68478d0e Mon Sep 17 00:00:00 2001 From: sunnycase Date: Tue, 14 Jan 2025 01:34:05 +0000 Subject: [PATCH] Refactor type infer --- Directory.Packages.props | 2 - .../Evaluator/CPU/Im2col.cs | 14 +- .../Nncase.Modules.CPU/Evaluator/CPU/Pack.cs | 2 +- .../Evaluator/CPU/PackedBinary.cs | 4 +- .../Evaluator/CPU/PackedLayerNorm.cs | 2 +- .../Evaluator/CPU/PackedMatMul.cs | 16 +- .../Evaluator/CPU/PackedReduce.cs | 2 +- .../Evaluator/TIR/CPU/Binary.cs | 2 +- .../Evaluator/TIR/CPU/Matmul.cs | 2 +- .../Evaluator/TIR/CPU/Pack.cs | 2 +- .../Evaluator/TIR/CPU/PackedBinary.cs | 2 +- .../Evaluator/TIR/CPU/Swish.cs | 2 +- .../Evaluator/TIR/CPU/Unary.cs | 2 +- .../Nncase.Modules.CPU/IR/CPU/PackedReduce.cs | 2 +- .../Passes/CPUFunctionPartitionPass.cs | 6 +- .../Passes/Distributed/AutoDistributed.cs | 75 ++++----- .../Passes/Distributed/CustomOpScheme.cs | 2 +- .../Passes/Rules/CPU/Affine/LowerBinary.cs | 4 +- .../Passes/Rules/CPU/Affine/LowerMatMul.cs | 2 +- .../Passes/Rules/CPU/GraphPartition.cs | 2 +- .../Passes/Rules/CPU/PackRule.cs | 12 +- .../Passes/Tile/DeviceToTIRVisitor.cs | 8 +- .../Passes/Tile/FusionChecker.cs | 44 +++--- .../Passes/Tile/PrimTileVisitor.cs | 16 +- .../Targets/CPUTargetOptions.cs | 2 +- .../Utilities/PackUtility.cs | 18 +-- .../CodeGen/StackVM/CodegenVisitor.cs | 6 +- .../src/kernels/stackvm/optimized/matmul.cpp | 8 +- .../src/kernels/stackvm/optimized/opt_ops.h | 6 +- .../kernels/stackvm/reference/layer_norm.cpp | 5 - src/Nncase.Cli/Program.cs | 2 +- src/Nncase.Core/CompilerServices.cs | 9 ++ src/Nncase.Core/DataType.cs | 2 +- .../Evaluator/ITypeInferenceContext.cs | 2 + src/Nncase.Core/Evaluator/Metric.cs | 2 +- src/Nncase.Core/IR/Const.cs | 2 +- src/Nncase.Core/IR/Dimension.cs | 124 +++++++++++---- src/Nncase.Core/IR/Expr.cs | 56 +------ src/Nncase.Core/IR/ExprRewriter.cs | 16 +- src/Nncase.Core/IR/IRHelpers.cs | 9 +- src/Nncase.Core/IR/IRModule.cs | 8 +- src/Nncase.Core/IR/Math/Functional.cs | 7 + src/Nncase.Core/IR/NN/SpaceToBatch.cs | 2 +- src/Nncase.Core/IR/Shape.cs | 60 +++++++- src/Nncase.Core/IR/TypePattern.cs | 2 +- src/Nncase.Core/Passes/GetReplaceUtility.cs | 8 +- src/Nncase.Core/Passes/IRewriteProvider.cs | 5 + src/Nncase.Core/Schedule/MicroKernelInfo.cs | 4 +- src/Nncase.Core/Tensor.cs | 66 ++++---- src/Nncase.Core/TensorOfT.cs | 48 +++--- src/Nncase.Core/TensorUtilities.cs | 98 ++++++++---- .../Utilities/DistributedUtility.cs | 44 +++--- src/Nncase.Core/Utilities/DumpUtility.cs | 2 +- .../Diagnostics/ILDotPrintVisitor.cs | 91 +++++++++++ .../Diagnostics/ILPrintVisitor.cs | 21 ++- .../CostModel/EGraphCostModel.cs | 6 + .../CostModel/EGraphSatPrinter.cs | 18 ++- src/Nncase.EGraph/Passes/EGraphExtractor.cs | 37 ++++- src/Nncase.Evaluator/Buffers/BufferSubview.cs | 4 +- src/Nncase.Evaluator/Buffers/Uninitialized.cs | 4 +- .../Extension/OrtKIExtensions.cs | 8 +- src/Nncase.Evaluator/Math/MatMul.cs | 6 +- src/Nncase.Evaluator/NN/BatchToSpace.cs | 9 +- src/Nncase.Evaluator/NN/Conv2DTranspose.cs | 2 +- src/Nncase.Evaluator/NN/LayerNorm.cs | 2 +- src/Nncase.Evaluator/NN/SpaceToBatch.cs | 7 +- src/Nncase.Evaluator/RNN/LSTM.cs | 7 +- .../ShapeExpr/BroadcastShape.cs | 3 +- .../ShapeExpr/SqueezeShape.cs | 3 +- .../ShapeExpr/UnsqueezeShape.cs | 3 +- src/Nncase.Evaluator/Tensors/Broadcast.cs | 12 +- src/Nncase.Evaluator/Tensors/Concat.cs | 25 +-- .../Tensors/ConstantOfShape.cs | 16 +- src/Nncase.Evaluator/Tensors/Expand.cs | 13 +- src/Nncase.Evaluator/Tensors/Flatten.cs | 14 +- src/Nncase.Evaluator/Tensors/GatherND.cs | 2 +- 
src/Nncase.Evaluator/Tensors/GetItem.cs | 144 +++++++++++------- src/Nncase.Evaluator/Tensors/Range.cs | 4 +- src/Nncase.Evaluator/Tensors/Reshape.cs | 46 +++--- src/Nncase.Evaluator/Tensors/ScatterND.cs | 8 +- src/Nncase.Evaluator/Tensors/ShapeOf.cs | 28 +++- src/Nncase.Evaluator/Tensors/Slice.cs | 32 ++-- src/Nncase.Evaluator/Tensors/Split.cs | 9 +- src/Nncase.Evaluator/Tensors/Squeeze.cs | 3 +- src/Nncase.Evaluator/Tensors/Tile.cs | 19 ++- src/Nncase.Evaluator/Tensors/UnSqueeze.cs | 5 +- src/Nncase.Evaluator/Tensors/Where.cs | 2 +- src/Nncase.Evaluator/TypeInference.cs | 102 +++++++------ src/Nncase.Evaluator/TypeInferenceVisitor.cs | 7 +- src/Nncase.Importer/Ncnn/Convolution.cs | 6 +- src/Nncase.Importer/Ncnn/NcnnModelBin.cs | 6 +- src/Nncase.Importer/Ncnn/Pooling.cs | 8 +- src/Nncase.Importer/Onnx/DataGatter.cs | 2 +- src/Nncase.Importer/Onnx/MatMul.cs | 3 +- src/Nncase.Importer/Onnx/Pad.cs | 2 +- src/Nncase.Importer/Onnx/QLinearConv.cs | 10 +- src/Nncase.Importer/TFLite/Conv2D.cs | 8 +- src/Nncase.Importer/TFLite/Conv2DTranspose.cs | 4 +- src/Nncase.Importer/TFLite/ReduceWindow2D.cs | 4 +- src/Nncase.Importer/TFLite/TFLiteImporter.cs | 2 +- .../BufferSchedule/BufferScheduleTypes.cs | 6 +- .../BufferSchedule/LifeTimeCollector.cs | 8 +- .../GraphPartition/GraphConvetor.cs | 126 +++++++-------- src/Nncase.Passes/PassesModule.cs | 1 + .../Rules/Arithmetic/IdentityLaw.cs | 2 +- .../Rules/Neutral/AddPreProcess.cs | 10 +- .../Rules/Neutral/AddToConv2D.cs | 2 +- .../Rules/Neutral/BatchNormToBinary.cs | 2 +- .../Rules/Neutral/CombineReshape.cs | 10 +- .../Rules/Neutral/CombineTranspose.cs | 10 +- .../Rules/Neutral/DecomposeInstancenorm.cs | 4 +- .../Rules/Neutral/DecomposeLayernorm.cs | 4 +- .../Rules/Neutral/FoldConv2DAddMul.cs | 2 +- .../Rules/Neutral/FoldDilatedConv2D.cs | 18 +-- .../Rules/Neutral/FoldReshape.cs | 4 +- .../Rules/Neutral/MatMulToConv2D.cs | 20 +-- .../Rules/Neutral/RemoveUnusedFunctions.cs | 2 +- .../Rules/Neutral/RemoveUnusedVars.cs | 6 +- .../Rules/Neutral/ReshapeExpand.cs | 3 +- .../Rules/Neutral/ReshapeMatMul.cs | 4 +- .../Rules/Neutral/ScalarConstToTensor.cs | 2 +- .../Rules/Neutral/SpaceToBatchTransform.cs | 2 +- .../Rules/Neutral/SplitSpaceToBatch.cs | 6 +- .../Rules/Neutral/SqueezeShape.cs | 26 ++-- .../Rules/ShapeBucket/MergeBucketFusion.cs | 16 +- .../Rules/ShapeBucket/MergeCallToFusion.cs | 4 +- .../Rules/ShapeBucket/ShapeBucket.cs | 1 + .../Rules/ShapeBucket/ShapeBucketHelper.cs | 1 - .../Rules/ShapeBucket/SplitLLMStage.cs | 2 + .../Rules/ShapeExpr/FoldSplitShapeOf.cs | 2 +- .../Rules/WithMarker/CombineReshapePad.cs | 4 +- .../WithMarker/FoldConv2DBiasWithMarker.cs | 2 - .../WithMarker/MatMulToConv2DWithMarker.cs | 10 +- src/Nncase.Passes/RulesUtility.cs | 2 +- src/Nncase.Passes/SimplifyProvider.cs | 73 +++++++++ .../Quantization/CalibrationEvaluator.cs | 4 +- .../PytestCalibrationDatasetProvider.cs | 8 +- .../Quantization/QuantUtility.cs | 6 +- .../Quantization/Quantizer.Algorithms.cs | 1 - .../Quantization/Quantizer.cs | 2 +- src/Nncase.Schedule/Schedule/AffineTiler.cs | 4 +- .../Schedule/TileGraph/TileGraphTypes.cs | 6 +- .../Schedule/TileGraph/TileTreeTypes.cs | 4 +- .../TileGraph/TreeSolverPythonPrinter.cs | 2 +- .../Schedule/TileTree/TileTreeTypes.cs | 6 +- .../Schedule/TileTree/TreeCloner.cs | 2 +- .../TileTree/TreeSolverInitializer.cs | 2 +- src/Nncase.Schedule/Schedule/TilingSolver.cs | 4 +- .../Schedule/TilingUtilities.cs | 6 +- .../Transforms/AutoTilePass.cs | 6 +- .../Runtime/Interop/RTTensor.cs | 4 +- .../ViewModels/SimulateViewModel.cs | 2 +- 
.../TestingServices.cs | 4 +- .../TransformBase/Compare.cs | 20 +-- .../TransformBase/DataGenerator.cs | 14 +- src/Nncase.Tests/Affine/UnitTestFor.cs | 4 +- src/Nncase.Tests/Core/IR/UnitTestConst.cs | 2 +- src/Nncase.Tests/Core/IR/UnitTestDimension.cs | 10 +- src/Nncase.Tests/Core/IR/UnitTestShape.cs | 12 +- src/Nncase.Tests/Core/UnitTestDumpUtility.cs | 12 +- src/Nncase.Tests/Core/UnitTestExpression.cs | 10 +- .../Core/UnitTestGetReplaceUtility.cs | 4 +- src/Nncase.Tests/Core/UnitTestIValue.cs | 40 ++--- src/Nncase.Tests/Core/UnitTestTensor.cs | 28 ++-- src/Nncase.Tests/Core/UnitTestTensorHelper.cs | 10 +- .../Core/UnitTestTensorOfT.Helper.cs | 4 +- src/Nncase.Tests/Core/UnitTestTensorOfT.cs | 50 +++--- .../Core/UnitTestTensorUtilities.cs | 88 +++++------ src/Nncase.Tests/Core/UnitTestTypeInfer.cs | 11 +- .../Diagnostics/UnitTestDumpper.cs | 2 +- .../Distributed/UnitTestCustomOpScheme.cs | 2 +- .../Evaluator/UnitTestEvaluator.cs | 4 +- .../Evaluator/UnitTestEvaluatorMath.cs | 98 ++++++------ .../Evaluator/UnitTestEvaluatorNN.cs | 36 ++--- .../Evaluator/UnitTestEvaluatorTensors.cs | 28 ++-- .../Evaluator/UnitTestShapeEvaluator.cs | 28 ++-- .../Quant/UnitTestAddRangeOfMarker.cs | 2 +- .../Quant/UnitTestDumpQuantError.cs | 14 +- .../Quant/UnitTestExportQuantScheme.cs | 14 +- .../Quant/UnitTestImportQuantScheme.cs | 8 +- .../Quant/UnitTestQuantAlgorithm.cs | 12 +- src/Nncase.Tests/Rewrite/RewriteBase.cs | 38 ++--- .../Rewrite/UnitTestDataFlowRewrite.cs | 10 +- .../Rewrite/UnitTestEGraphRewrite.cs | 4 +- .../Rules/Neutral/UnitTestCombineBinary.cs | 10 +- .../Rules/Neutral/UnitTestCombineReshape.cs | 66 ++++---- .../Rules/Neutral/UnitTestCombineTranspose.cs | 98 ++++++------ .../Rules/Neutral/UnitTestFlattenToReshape.cs | 2 +- .../Rules/Neutral/UnitTestFoldLayerNorm.cs | 18 +-- .../Neutral/UnitTestReshapeBatchMatmul.cs | 2 +- .../Neutral/UnitTestSpaceToBatchTransform.cs | 2 +- .../Rules/Neutral/UnitTestSqueezeToReshape.cs | 2 +- .../Neutral/UnitTestUnSqueezeToReshape.cs | 2 +- .../Rules/Packing/PackUtilityTest.cs | 8 +- .../Rules/ShapeBucket/ShapeBucketTest.cs | 24 +-- .../ShapeExpr/UnitTestFoldGetItemShapeOf.cs | 6 +- src/Nncase.Tests/Simulator/UnitTestInterop.cs | 4 +- .../TIR/PrimFunc/IDataFlowPrimFuncCase.cs | 2 +- .../Targets/UnitTestCPUKernels.cs | 4 +- src/Nncase.Tests/Targets/UnitTestCPUTarget.cs | 2 +- 200 files changed, 1659 insertions(+), 1309 deletions(-) create mode 100644 src/Nncase.Passes/SimplifyProvider.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index d6257d69f..c6a54caa0 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -28,8 +28,6 @@ - - diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Im2col.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Im2col.cs index b838d8c2e..d62ac0e70 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/Im2col.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Im2col.cs @@ -35,10 +35,10 @@ public IValue Visit(IEvaluateContext context, Im2col target) { var inputTensor = context.GetArgumentValueAsTensor(target, Im2col.Input); var lanes = inputTensor.ElementType.SizeInBytes / 4; - int batch = inputTensor.Shape[0].FixedValue; - int inChannel = inputTensor.Shape[1].FixedValue; - int height = inputTensor.Shape[2].FixedValue; - int width = inputTensor.Shape[3].FixedValue; + int batch = (int)inputTensor.Shape[0].FixedValue; + int inChannel = (int)inputTensor.Shape[1].FixedValue; + int height = (int)inputTensor.Shape[2].FixedValue; + int width = (int)inputTensor.Shape[3].FixedValue; int pad_h_before = 
target.Padding[0]; int pad_h_after = target.Padding[1]; int pad_w_before = target.Padding[2]; @@ -96,7 +96,7 @@ public IValue Visit(IEvaluateContext context, Im2col target) } } - return Value.FromTensor(Tensor.FromBytes(inputTensor.ElementType, System.Runtime.InteropServices.MemoryMarshal.Cast(outputTensor).ToArray(), new[] { inChannel * kernel_h * kernel_w, batch * output_h * output_w })); + return Value.FromTensor(Tensor.FromBytes(inputTensor.ElementType, System.Runtime.InteropServices.MemoryMarshal.Cast(outputTensor).ToArray(), [inChannel * kernel_h * kernel_w, batch * output_h * output_w])); } private IRType Visit(DistributedType dt, Im2col target) @@ -143,8 +143,8 @@ private IRType Visit(DistributedType dt, Im2col target) private IRType Visit(TensorType tt, Im2col target) { - int height = tt.Shape[2].FixedValue; - int width = tt.Shape[3].FixedValue; + int height = (int)tt.Shape[2].FixedValue; + int width = (int)tt.Shape[3].FixedValue; int pad_h_before = target.Padding[0]; int pad_h_after = target.Padding[1]; int pad_w_before = target.Padding[2]; diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs index bef0440cf..9e9aa9fd1 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs @@ -35,7 +35,7 @@ public IValue Visit(IEvaluateContext context, Pack target) var output = Cast(inputOrt.ToTensor(), context.CurrentCall.Arguments[Pack.Input.Index].CheckedDataType).Evaluate().AsTensor(); - return Value.FromTensor(Tensor.FromBytes(new VectorType(output.ElementType, target.Lanes), output.BytesBuffer.ToArray(), inputOrt.Shape.SkipLast(target.Lanes.Count).Select(i => (int)i).ToArray())); + return Value.FromTensor(Tensor.FromBytes(new VectorType(output.ElementType, target.Lanes), output.BytesBuffer.ToArray(), inputOrt.Shape.SkipLast(target.Lanes.Count).Select(i => i).ToArray())); } else { diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs index 8a1d5fa5a..b1d21a797 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedBinary.cs @@ -95,7 +95,7 @@ public Cost Visit(ICostEvaluateContext context, PackedBinary target) private IRType Visit(PackedBinary target, TensorType a, TensorType b) { var rank = System.Math.Max(a.Shape.Rank, b.Shape.Rank); - var outShape = new int[rank]; + var outShape = new long[rank]; var lhsOrginShape = a.Shape.ToValueArray(); var rhsOrginShape = b.Shape.ToValueArray(); for (int i = 0; i < target.LhsPackedAxes.Count; i++) @@ -127,7 +127,7 @@ private IRType Visit(PackedBinary target, TensorType a, TensorType b) case ( >= 0, >= 0): switch (lhsOrginShape[aAxis], rhsOrginShape[bAxis]) { - case (int l, int r) when l == r: + case (long l, long r) when l == r: outShape[rank + i] = a.Shape[aAxis].FixedValue; orginKinds[rank + i] = DimKind.E; break; diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs index 9d7b2ee3a..f25e626f1 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedLayerNorm.cs @@ -41,7 +41,7 @@ public IValue Visit(IEvaluateContext context, PackedLayerNorm target) var outputTensor = OrtKISharp.Tensor.MakeTensor(new Memory(output), OrtDataType.Float, unpackedInput.Shape); outputTensor = CPUEvaluatorUtility.RepackTensor(outputTensor, lanes, target.PackedAxes, 
target.PadedNums); - return Value.FromTensor(Tensor.FromBytes(new VectorType(DataTypes.Float32, lanes), outputTensor.BytesBuffer.ToArray(), outputTensor.Shape.SkipLast(target.PackedAxes.Count).Select(i => (int)i).ToArray())); + return Value.FromTensor(Tensor.FromBytes(new VectorType(DataTypes.Float32, lanes), outputTensor.BytesBuffer.ToArray(), outputTensor.Shape.SkipLast(target.PackedAxes.Count).Select(i => i).ToArray())); } /// diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs index b99e07ba6..28fe7a1b3 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs @@ -22,20 +22,20 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target) var outRank = context.CurrentCall.CheckedShape.Rank; var outLanes = Array.Empty(); - var outShape = Array.Empty(); + var outShape = Array.Empty(); var axes = Array.Empty(); var (lm, lk) = target.TransposeA ? (lhs.Rank - target.RhsPackedAxes.Count - 1, lhs.Rank - target.RhsPackedAxes.Count - 2) : (lhs.Rank - target.LhsPackedAxes.Count - 2, lhs.Rank - target.LhsPackedAxes.Count - 1); var (rk, rn) = target.TransposeB ? (rhs.Rank - target.RhsPackedAxes.Count - 1, rhs.Rank - target.RhsPackedAxes.Count - 2) : (rhs.Rank - target.RhsPackedAxes.Count - 2, rhs.Rank - target.RhsPackedAxes.Count - 1); if (target.LhsPackedAxes.Count == 0 && target.RhsPackedAxes.Count == 1) { outLanes = new[] { (int)rhs.Shape[^1] }; - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; axes = new[] { outRank - 1 }; } else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 0) { outLanes = new[] { (int)lhs.Shape[^1] }; - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; axes = new[] { outRank - 2 }; } else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 1) @@ -51,24 +51,24 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target) axes = new[] { outRank - 2, outRank - 1 }; } - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; } else if (target.LhsPackedAxes.Count == 1 && target.RhsPackedAxes.Count == 2) { outLanes = new[] { (int)rhs.Shape[^1] }; - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; axes = new[] { outRank - 1 }; } else if (target.LhsPackedAxes.Count == 2 && target.RhsPackedAxes.Count == 1) { outLanes = new[] { (int)lhs.Shape[^2] }; - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; axes = new[] { outRank - 2 }; } else if (target.LhsPackedAxes.Count == 2 && target.RhsPackedAxes.Count == 2) { outLanes = new[] { (int)lhs.Shape[^2], (int)rhs.Shape[^1] }; - outShape = new[] { (int)lhs.Shape[lm], (int)rhs.Shape[rn] }; + outShape = new[] { lhs.Shape[lm], rhs.Shape[rn] }; axes = new[] { outRank - 2, outRank - 1 }; } else @@ -79,7 +79,7 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target) var maxRank = System.Math.Max(lhs.Shape.Length - target.LhsPackedAxes.Count, rhs.Shape.Length - target.RhsPackedAxes.Count); outShape = Enumerable.Repeat(1L, maxRank - lhs.Shape.Length + target.LhsPackedAxes.Count).Concat(lhs.Shape.SkipLast(2 + target.LhsPackedAxes.Count)). 
Zip(Enumerable.Repeat(1L, maxRank - rhs.Shape.Length + target.RhsPackedAxes.Count).Concat(rhs.Shape.SkipLast(2 + target.RhsPackedAxes.Count))). - Select(p => (int)System.Math.Max(p.First, p.Second)). + Select(p => System.Math.Max(p.First, p.Second)). Concat(outShape).ToArray(); foreach (var axis in target.LhsPackedAxes.Reverse()) diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs index 4a7a67c08..c4c4a481f 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs @@ -19,7 +19,7 @@ public sealed class PackedReduceEvaluator : IEvaluator, ITypeInfer public IValue Visit(IEvaluateContext context, PackedReduce target) { var input = context.GetOrtArgumentValue(target, PackedReduce.Input); - var inshape = input.Shape.SkipLast(target.PackedAxes.Count).Select(i => (int)i).ToArray(); + var inshape = input.Shape.SkipLast(target.PackedAxes.Count).Select(i => i).ToArray(); var inlanes = input.Shape.TakeLast(target.PackedAxes.Count).Select(i => (int)i).ToArray(); var unpackedInput = CPUEvaluatorUtility.UnpackTensor(input, target.PackedAxes, target.PadedNums, out _); var axes = target.Axes.Select(i => (long)i).ToArray(); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs index 9deea14fe..7f270fcd0 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Binary.cs @@ -20,7 +20,7 @@ public MicroKernelInfo Visit(Binary op, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs index 8bf092a53..acc8bba15 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs @@ -17,7 +17,7 @@ public MicroKernelInfo Visit(Matmul op, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var (k, m, n) = (context.BufferShapes[0][^1], context.BufferShapes[2][^2], context.BufferShapes[2][^1]); var (lpack, rpack) = PackedMatMul.GetPackKind(op.LhsPackedAxes, op.RhsPackedAxes); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs index 941867571..597045be0 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Pack.cs @@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Nncase.TIR.CPU.Pack op, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = 
Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs index da67de2e9..daa8876ac 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/PackedBinary.cs @@ -25,7 +25,7 @@ public MicroKernelInfo Visit(PackedBinary op, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs index d3a2ab790..f126b70e3 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Swish.cs @@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Swish swish, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs index 2f298ef60..0b4309e9a 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unary.cs @@ -23,7 +23,7 @@ public MicroKernelInfo Visit(Unary op, MicroKernelContext context) { var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); - var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs index 31a21c060..616664635 100644 --- a/modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedReduce.cs @@ -25,7 +25,7 @@ public sealed partial class PackedReduce : Op public IRArray PadedNums { get; } - public static (int[] 
OutPackAxes, int[] OutPadNums, int[] OutLanes, int[] OutShape) ComputeOutputInfo(PackedReduce target, int[] inShape, int[] inLanes) + public static (int[] OutPackAxes, int[] OutPadNums, int[] OutLanes, long[] OutShape) ComputeOutputInfo(PackedReduce target, long[] inShape, int[] inLanes) { var packedAxes = target.PackedAxes.ToList(); var padedNums = target.PadedNums.ToList(); diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs index 815f0d1c1..1c5e64256 100644 --- a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs +++ b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs @@ -31,9 +31,9 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext c var ctx = new GraphPartition.GraphContext(); var convertor = new GraphPartition.GraphConvertor(x => x switch { - Call call => (call.Target is IR.CPU.Boxing || call.CheckedType is DistributedType) ? true : false, - IR.Tuple tp => tp.Fields.ToArray().Any(f => f is Call { Target: IR.CPU.Boxing } b && b.CheckedType is TensorType) ? false : true, - _ => throw new NotSupportedException(), + Call call => (call.Target is IR.CPU.Boxing || call.CheckedType is DistributedType) ? Compat.COMPATIBLE : Compat.INCOMPATIBLE, + IR.Tuple tp => tp.Fields.ToArray().Any(f => f is Call { Target: IR.CPU.Boxing } b && b.CheckedType is TensorType) ? Compat.INCOMPATIBLE : Compat.COMPATIBLE, + _ => Compat.INCOMPATIBLE, }); convertor.Visit(pre.Body, ctx); diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 74423538a..dbc8de767 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -113,12 +113,29 @@ public static void MemoryExtractConstrains(CpModel model, IReadOnlyDictionary GetLeafCandidateBoxings(Expr expr, IEnumerable placements) { - return placements.Select( - placement => - Utilities.DistributedUtility.GetLeafCandidateNDSBPs((TensorType)expr.CheckedType, placement). - Select(ndsbp => - IR.F.CPU.Boxing(expr, new DistributedType((TensorType)expr.CheckedType, ndsbp, placement)))). - SelectMany(e => e).ToArray(); + if (expr.CheckedType is InvalidType) + { + return [expr]; + } + + if (expr is IR.Tuple tuple) + { + return tuple.Fields.ToArray(). + Select(e => IsDistributed(e.CheckedType) ? [e] : GetLeafCandidateBoxings(e, placements)). + CartesianProduct(). + Select(fs => new IR.Tuple(fs.ToArray())). + ToArray(); + } + else + { + // Don't use expr.CheckedTensorType + return placements.Select( + placement => + Utilities.DistributedUtility.GetLeafCandidateNDSBPs((TensorType)expr.CheckedType, placement). + Select(ndsbp => + IR.F.CPU.Boxing(expr, new DistributedType((TensorType)expr.CheckedType, ndsbp, placement)))). 
+ SelectMany(e => e).ToArray(); + } } public static IReadOnlyList> GetDiverseCandidateSBPs(DistributedType distributedType, IEnumerable placements) @@ -408,7 +425,7 @@ protected override Dictionary> VisitLeafCall(Call expr) } // TODO: refactor here - if (expr.Target is not ScatterND && expr.Target is not Boxing && (expr.CheckedType is TensorType or DistributedType) && expr.CheckedShape.IsFixed && !expr.CheckedShape.ToValueArray().Contains(0) && results.Count == 1 && results.First().Key is DistributedType dt && dt.NdSBP.All(sbp => sbp is SBPBroadCast)) + if (expr.Target is not ScatterND && expr.Target is not Boxing && (expr.CheckedType is TensorType or DistributedType) && expr.CheckedShape.All(x => x != 0) && results.Count == 1 && results.First().Key is DistributedType dt && dt.NdSBP.All(sbp => sbp is SBPBroadCast)) { return expr.Arguments.ToArray(). Select(Visit). @@ -428,6 +445,13 @@ protected override Dictionary> VisitLeafCall(Call expr) return results; } + private static bool IsDistributed(IRType type) => type switch + { + DistributedType => true, + TupleType t => t.All(IsDistributed), + _ => false, + }; + private Dictionary> VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported) { var updateBuckets = (Dictionary> buckets, IEnumerable equivalents) => @@ -499,6 +523,11 @@ private Dictionary> VisitLeafArgument(ParameterKind parameter } } + if (!buckets.Any()) + { + throw new InvalidOperationException(); + } + return buckets; } @@ -524,10 +553,6 @@ private IEnumerable BuildEquivalCalls(Op target, Expr[] args) var valid = call.InferenceType(); if (!valid) { - // 1. dispose current call - using var pinner = new ExprPinner(args); - call.Dispose(); - if (target is Reshape) { // the reshape need force boxing. @@ -537,11 +562,7 @@ private IEnumerable BuildEquivalCalls(Op target, Expr[] args) foreach (var boxing in Utilities.DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement). Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement), true))) { - if (boxing.CheckedType is InvalidType) - { - boxing.Dispose(); - } - else + if (boxing.CheckedType is not InvalidType) { calls.Add(boxing); } @@ -580,11 +601,7 @@ private IEnumerable BuildEquivalCalls(Op target, Expr[] args) var extraBoxings = partialBoxings.Any() ? 
partialBoxings.Select(getExtraBoxings).SelectMany(i => i) : getExtraBoxings(call); foreach (var boxing in extraBoxings) { - if (boxing.CheckedType is InvalidType) - { - boxing.Dispose(); - } - else + if (boxing.CheckedType is not InvalidType) { calls.Add(boxing); } @@ -592,6 +609,7 @@ private IEnumerable BuildEquivalCalls(Op target, Expr[] args) } } + // GC.Collect(); return calls; } @@ -615,7 +633,7 @@ private IReadOnlyList GetReBoxings(Expr expr) candidateNdsbps[i] = new List { SBP.B }; for (int axis = 0; axis < tensorType.Shape.Rank; axis++) { - if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && Utilities.DistributedUtility.IsDivideExactly(s, type.Placement.Hierarchy[i])) + if (tensorType.Shape[axis] is { IsFixed: true, FixedValue: long s } && Utilities.DistributedUtility.IsDivideExactly(s, type.Placement.Hierarchy[i])) { candidateNdsbps[i].Add(SBP.S(axis)); } @@ -629,13 +647,6 @@ private IReadOnlyList GetReBoxings(Expr expr) Select(disttype => IR.F.CPU.Boxing(expr, disttype)).ToArray(); } - private bool IsDistributed(IRType type) => type switch - { - DistributedType => true, - TupleType t => t.All(IsDistributed), - _ => false, - }; - private Expr InstertTerminator(Expr expr) { Expr CreateFinalBoxing(Expr e, DistributedType type) @@ -673,6 +684,7 @@ private EClass Unions(EGraph graph, IEnumerable equivalents) private void BranchCut() { + GC.Collect(); bool changed = true; while (changed) { @@ -689,11 +701,6 @@ private void BranchCut() { throw new InvalidOperationException("this item can't have more than zero users!"); } - - using (new ExprPinner(item.Operands.ToArray())) - { - item.Dispose(); - } } buket.Clear(); diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/CustomOpScheme.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/CustomOpScheme.cs index 5cf26928b..bab9e5e22 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/CustomOpScheme.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/CustomOpScheme.cs @@ -5,7 +5,7 @@ namespace Nncase.Passes.Distributed; public record class CustomOpScheme(string Version, string Model, CustomOpScheme.Node[] Outputs) { - public record class Node(string? Name, string Op, int[][] Shape, IR.SBP[][] SBP, ulong Cost, string CSourcePath) + public record class Node(string? 
Name, string Op, long[][] Shape, IR.SBP[][] SBP, ulong Cost, string CSourcePath) { } } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerBinary.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerBinary.cs index 31375dbfe..87412d984 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerBinary.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerBinary.cs @@ -48,7 +48,7 @@ public LowerBinary(string moduleKind = CPUTarget.Kind) case ( >= 0, >= 0): switch (lhsShape[lhsi], rhsShape[rhsi]) { - case (int a, int b) when a == b: + case (long a, long b) when a == b: lhsRes[lhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); rhsRes[rhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); break; @@ -132,7 +132,7 @@ public LowerPackedBinary(string moduleKind = CPUTarget.Kind) case ( >= 0, >= 0): switch (lhsShape[lhsi], rhsShape[rhsi]) { - case (int a, int b) when a == b: + case (long a, long b) when a == b: lhsRes[lhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); rhsRes[rhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); break; diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerMatMul.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerMatMul.cs index 812d06364..0a13baa4d 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerMatMul.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerMatMul.cs @@ -46,7 +46,7 @@ public LowerMatmul(string moduleKind = CPUTarget.Kind) case ( >= 0, >= 0): switch (lhsShape[lhsi], rhsShape[rhsi]) { - case (int a, int b) when a == b: + case (long a, long b) when a == b: lhsRes[lhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); rhsRes[rhsi] = new AffineRange(domains[i].Offset, domains[i].Extent); break; diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs index 662371143..8536519e3 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/GraphPartition.cs @@ -92,7 +92,7 @@ Expr BuildNewInput(Expr parameter) case Const c: if (parameter is TensorConst { Value: Tensor { Shape.IsScalar: true } } tc) { - newInput = Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })); + newInput = Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), [1])); } else { diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs index e74db57b8..35af8bc4d 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs @@ -386,7 +386,7 @@ private void AddCandidate(RuleContext context, IR.CPU.PackedMatMul.PackKind lhsP } } - private sealed record RuleContext(List Results, Expr Lhs, Expr Rhs, Expr Candidate, IReadOnlyList LhsShape, IReadOnlyList RhsShape) + private sealed record RuleContext(List Results, Expr Lhs, Expr Rhs, Expr Candidate, IReadOnlyList LhsShape, IReadOnlyList RhsShape) { } } @@ -486,7 +486,7 @@ void AddCandidate(int[] lhsPackedAxes, int[] rhsPackedAxes, int[] lhsLanes, int[ return rets; } - public IEnumerable GeneratePackAxes(int[] shape) + public IEnumerable GeneratePackAxes(long[] shape) { if (shape.Length == 0 || (shape.Length == 1 && shape[0] == 1)) { @@ -689,7 +689,7 @@ public PackConv2D(int rank, int lane) IsTensorConst("groups"), IsTensorConst("fusedClamp")); - 
public static Expr AddCandidate(Expr input, Expr weights, Expr bias, int[] strides, int[] padding, int[] wShape, int[] outShape) + public static Expr AddCandidate(Expr input, Expr weights, Expr bias, int[] strides, int[] padding, int[] wShape, long[] outShape) { var col = IR.F.CPU.Im2col(input, new[] { wShape[2], wShape[3] }, strides, padding); var newW = IR.F.Tensors.Reshape(weights, new[] { wShape[0], wShape[1] * wShape[2] * wShape[3] }); @@ -704,7 +704,7 @@ public static Expr AddCandidate(Expr input, Expr weights, Expr bias, int[] strid return IR.F.Tensors.Transpose(IR.F.Tensors.Reshape(add, new[] { outShape[1], outShape[0], outShape[2], outShape[3] }), new[] { 1, 0, 2, 3 }); } - public static Expr AddPackedCandidate(Expr input, Expr weights, Expr bias, int[] strides, int[] padding, int[] wShape, int[] outShape, int lane) + public static Expr AddPackedCandidate(Expr input, Expr weights, Expr bias, int[] strides, int[] padding, int[] wShape, long[] outShape, int lane) { var col = IR.F.CPU.Im2col(IR.F.CPU.Pack(input, new[] { lane }, new[] { 1 }), new[] { wShape[2], wShape[3] }, strides, padding, new[] { 1 }, new[] { 0 }); var newW = IR.F.Tensors.Reshape(IR.F.CPU.Pack(weights, new[] { lane }, new[] { 1 }), new[] { wShape[0], wShape[1] / lane * wShape[2] * wShape[3] }); @@ -733,7 +733,7 @@ public override List GetReplaceCandidates(IMatchResult result, RunPassCont var dilation = ((TensorConst)result["dilation"]).Value.ToArray(); var groups = ((TensorConst)result["groups"]).Value.ToScalar(); var fusedClamp = ((TensorConst)result["fusedClamp"]).Value.ToArray(); - var wShape = weights.CheckedShape.ToValueArray(); + var wShape = weights.CheckedShape.ToValueArray().ToInts(); var outShape = ((Expr)result[Pattern]).CheckedShape.ToValueArray(); if (groups != 1 || wShape[1] % Lane != 0 || dilation[0] != 1 || dilation[1] != 1 || fusedClamp[0] != float.NegativeInfinity || fusedClamp[1] != float.PositiveInfinity) { @@ -764,7 +764,7 @@ public override List GetReplaceCandidates(IMatchResult result, RunPassCont var rets = new List(); var input = (Expr)result["input"]; - var newShape = ((TensorConst)result["newShape"]).Value.ToArray(); + var newShape = ((TensorConst)result["newShape"]).Value.ToArray(); var inShape = input.CheckedShape.ToValueArray(); // 1. 
find the mapping transforms diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs index 05e6aa0bf..3493e8013 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/DeviceToTIRVisitor.cs @@ -243,7 +243,7 @@ private ISequentialBuilder LowerMatmul(Call call, MatMul op, AffineM var tileShape = GetTile(call); var fullShape = GetShape(call); - Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) + Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) { var ralign = outShape.Count - inShape.Count; for (int i = outShape.Count - 1; i >= 0; i--) @@ -451,7 +451,7 @@ private ISequentialBuilder LowerBinary(Call call, Binary op, AffineM var rhsRegion = GetBufferRegion(call.Arguments[1], (TIR.Buffer inBuffer) => new BufferRegion(inBuffer, outRegion.Region)); TileScope.CurrentBlock.Alloc(outRegion.Buffer); - Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) + Expr[] PostProcessAffineMap(List iters, IReadOnlyList inShape, IReadOnlyList outShape) { var ralign = outShape.Count - inShape.Count; for (int i = outShape.Count - 1; i >= 0; i--) @@ -604,9 +604,9 @@ private Expr[] ComputeIndcies(TIR.Buffer top, Expr[] loopvars, AffineMap rootMap return newLoopvars; } - private IReadOnlyList GetTile(Expr expr) => _tileMemo[expr].TileShape; + private IReadOnlyList GetTile(Expr expr) => _tileMemo[expr].TileShape; - private IReadOnlyList GetShape(Expr expr) => _tileMemo[expr].OutShape; + private IReadOnlyList GetShape(Expr expr) => _tileMemo[expr].OutShape; private BufferRegion GetBufferRegion(Expr expr, Func createFunc) { diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs b/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs index f1abef8a0..dd4583b3a 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/FusionChecker.cs @@ -30,7 +30,7 @@ public sealed class NodeInfo : IDisposable private readonly ExprPinner _pinner; private readonly TIR.Buffer? _buffer; - public NodeInfo(TIR.Buffer? buffer, int[] tileShape, int[] outShape) + public NodeInfo(TIR.Buffer? buffer, long[] tileShape, long[] outShape) { _buffer = buffer; TileShape = tileShape; @@ -47,9 +47,9 @@ public NodeInfo(TIR.Buffer? buffer, int[] tileShape, int[] outShape) public TIR.Buffer Buffer => _buffer!; - public IReadOnlyList OutShape { get; } + public IReadOnlyList OutShape { get; } - public int[] TileShape { get; set; } + public long[] TileShape { get; set; } public void Dispose() => _pinner.Dispose(); } @@ -60,10 +60,10 @@ public sealed record TileFragment(BucketCondition Condition, IReadOnlyDictionary public sealed class FusionChecker { - private readonly List> _initTileList; + private readonly List> _initTileList; private IReadOnlyList? 
_checkedResult; - public FusionChecker(List> initTileList) + public FusionChecker(List> initTileList) { _initTileList = initTileList; } @@ -181,16 +181,16 @@ public IReadOnlyList Check(Expr root) return _checkedResult = conditions.Zip(tileMaps).Select(p => new TileFragment(p.First, p.Second)).ToList(); } - private static List> GetCandidateKs(Dictionary bucket) + private static List> GetCandidateKs(Dictionary bucket) { - var allKs = new Dictionary>(); + var allKs = new Dictionary>(); foreach (var kv in bucket) { if (kv.Key is Call { Target: MatMul op } call) { var k = bucket[call[op.Parameters.First()]].Last(); - var ks = new List(); - for (int i = 32; i < k; i += 32) + var ks = new List(); + for (long i = 32; i < k; i += 32) { ks.Add(i); } @@ -200,20 +200,20 @@ private static List> GetCandidateKs(Dictionary>> ret = new[] { Enumerable.Empty>() }; + IEnumerable>> ret = new[] { Enumerable.Empty>() }; foreach (var kvp in allKs) { ret = from seq in ret from item in kvp.Value - select seq.Concat(new[] { new KeyValuePair(kvp.Key, item) }); + select seq.Concat(new[] { new KeyValuePair(kvp.Key, item) }); } return ret.Select(seq => seq.ToDictionary(kv => kv.Key, kv => kv.Value)).ToList(); } - private (List> Buckets, List Conditions) GetSplitBuckets() + private (List> Buckets, List Conditions) GetSplitBuckets() { - var buckets = new Dictionary>(); + var buckets = new Dictionary>(); foreach (var s in GetCandidateBuckets()) { buckets.Add(s, new()); @@ -341,7 +341,7 @@ from item in kvp.Value } } - List> ret = new(); + List> ret = new(); List conditions = new(); foreach (BucketCondition s in GetCandidateBuckets()) { @@ -379,7 +379,7 @@ private IEnumerable GetCandidateBuckets() => Select(p => p.ToArray()). Select(a => new BucketCondition(a[0], a[1], a[2])); - private bool TryAllocate(Dictionary tileMap, Dictionary bucket, bool finalAllocate = false) + private bool TryAllocate(Dictionary tileMap, Dictionary bucket, bool finalAllocate = false) { var tileList = new List>(); var exprs = ExprCollector.Collect(_initTileList.Last().Key).Where(e => e is not Op); @@ -402,7 +402,7 @@ private bool TryAllocate(Dictionary tileMap, Dictionary TryAllocate(List> tileList, Dictionary bucket, bool finalAllocate = false) + private Dictionary TryAllocate(List> tileList, Dictionary bucket, bool finalAllocate = false) { // TODO: // 1. 
支持不同数据类型的检查 @@ -519,7 +519,7 @@ void UpdateLifeness(int start, Expr expr, TIR.Buffer buffer, bool updateEnd) return ret; } - private void Visit(Call expr, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) + private void Visit(Call expr, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) { switch (expr.Target) { @@ -537,7 +537,7 @@ private void Visit(Call expr, Dictionary tileMap, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) + private void VisitIdenity(Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k = -1) { var inTileShape = tileMap[call].TileShape; var input = call.Arguments[0]; @@ -560,16 +560,16 @@ private void VisitIdenity(Call call, Dictionary tileMap, Diction } } - private void VisitMatmul(IR.Math.MatMul op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) + private void VisitMatmul(IR.Math.MatMul op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) { var lhs = call.Arguments[0]; var rhs = call.Arguments[1]; var outTileShape = tileMap[call].TileShape; - var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray(); + var inTileShapeA = Enumerable.Repeat(1L, lhs.CheckedShape.Rank).ToArray(); inTileShapeA[^2] = outTileShape[^2]; inTileShapeA[^1] = candidateKs[k][call]; - var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray(); + var inTileShapeB = Enumerable.Repeat(1L, rhs.CheckedShape.Rank).ToArray(); inTileShapeB[^2] = candidateKs[k][call]; inTileShapeB[^1] = outTileShape[^1]; @@ -610,7 +610,7 @@ private void VisitMatmul(IR.Math.MatMul op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) + private void VisitBinary(IR.Math.Binary op, Call call, Dictionary tileMap, Dictionary bucketMap, List> candidateKs, int k) { var lhs = call.Arguments[0]; var rhs = call.Arguments[1]; diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs index 15da06055..f1b898163 100644 --- a/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs +++ b/modules/Nncase.Modules.CPU/Passes/Tile/PrimTileVisitor.cs @@ -15,7 +15,7 @@ public PrimTileVisitor() Count = 0; } - public List> TileList { get; } + public List> TileList { get; } public List> NameList { get; } @@ -34,9 +34,9 @@ protected override Unit VisitLeafCall(Call expr) { var lhs = expr.Arguments[0]; var rhs = expr.Arguments[1]; - var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray(); + var inTileShapeA = Enumerable.Repeat(1L, lhs.CheckedShape.Rank).ToArray(); Array.Fill(inTileShapeA, 32, inTileShapeA.Length - 2, 2); - var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray(); + var inTileShapeB = Enumerable.Repeat(1L, rhs.CheckedShape.Rank).ToArray(); Array.Fill(inTileShapeB, 32, inTileShapeB.Length - 2, 2); if (!(lhs is Var or TensorConst)) @@ -61,7 +61,7 @@ protected override Unit VisitLeafCall(Call expr) NameList.Add(new(rhs, nameof(IR.Math.MatMul) + "_" + Count.ToString() + "_rhs")); } - var outTileShape = Enumerable.Repeat(1, expr.CheckedShape.Rank).ToArray(); + var outTileShape = Enumerable.Repeat(1L, expr.CheckedShape.Rank).ToArray(); outTileShape[^1] = inTileShapeB[^1]; outTileShape[^2] = inTileShapeA[^2]; TileList.Add(new(expr, outTileShape)); @@ -73,7 +73,7 @@ protected override Unit VisitLeafCall(Call expr) case IR.Math.Unary or IR.CPU.Store or IR.CPU.Load: { var input = expr.Arguments[0]; - var inTileShape = 
Enumerable.Repeat(1, input.CheckedShape.Rank).ToArray(); + var inTileShape = Enumerable.Repeat(1L, input.CheckedShape.Rank).ToArray(); inTileShape[^1] = 32; if (!(input is Var or TensorConst)) @@ -98,9 +98,9 @@ protected override Unit VisitLeafCall(Call expr) { var lhs = expr.Arguments[0]; var rhs = expr.Arguments[1]; - var inTileShapeA = Enumerable.Repeat(1, lhs.CheckedShape.Rank).ToArray(); + var inTileShapeA = Enumerable.Repeat(1L, lhs.CheckedShape.Rank).ToArray(); inTileShapeA[^1] = 32; - var inTileShapeB = Enumerable.Repeat(1, rhs.CheckedShape.Rank).ToArray(); + var inTileShapeB = Enumerable.Repeat(1L, rhs.CheckedShape.Rank).ToArray(); inTileShapeB[^1] = 32; if (!(lhs is Var or TensorConst)) @@ -125,7 +125,7 @@ protected override Unit VisitLeafCall(Call expr) NameList.Add(new(rhs, nameof(IR.Math.Binary) + "_" + Count + "_rhs")); } - var outTileShape = Enumerable.Repeat(1, expr.CheckedShape.Rank).ToArray(); + var outTileShape = Enumerable.Repeat(1L, expr.CheckedShape.Rank).ToArray(); outTileShape[^1] = 32; TileList.Add(new(expr, outTileShape)); NameList.Add(new(expr, nameof(IR.Math.Binary) + "_" + Count)); diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs index 28dd50284..67cda6e9b 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs @@ -74,7 +74,7 @@ public class CpuTargetOptions : ICpuTargetOptions [DisplayName("--hierarchy-names")] [Description("the name identify of hierarchies.")] - [DefaultValue("b")] + [DefaultValue("t")] public string HierarchyNames { get; set; } = "t"; [DisplayName("--hierarchy-sizes")] diff --git a/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs b/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs index 88de3b7d5..2819d327d 100644 --- a/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs +++ b/modules/Nncase.Modules.CPU/Utilities/PackUtility.cs @@ -7,7 +7,7 @@ namespace Nncase.Utilities; public static class PackUtility { - public static Expr PadForPack(Expr input, int[] shape, int[] packedAxes, int[] lanes, Expr value, out int[] padNums) + public static Expr PadForPack(Expr input, long[] shape, int[] packedAxes, int[] lanes, Expr value, out int[] padNums) { var isPadded = false; var pads = new int[shape.Length, 2]; @@ -16,7 +16,7 @@ public static Expr PadForPack(Expr input, int[] shape, int[] packedAxes, int[] l var axis = packedAxes[i]; if (shape[axis] % lanes[i] != 0) { - pads[axis, 1] = MathUtility.AlignUp(shape[axis], lanes[i]) - shape[axis]; + pads[axis, 1] = (int)(MathUtility.AlignUp(shape[axis], lanes[i]) - shape[axis]); isPadded = true; } } @@ -35,7 +35,7 @@ public static Expr PadForPack(Expr input, int[] shape, int[] packedAxes, int[] l return input; } - public static Expr SliceForPack(Expr input, int[] shape, int[] padNums) + public static Expr SliceForPack(Expr input, long[] shape, int[] padNums) { bool isPadded = false; var ends = shape.ToArray(); @@ -44,7 +44,7 @@ public static Expr SliceForPack(Expr input, int[] shape, int[] padNums) isPadded = true; } - return isPadded ? IR.F.Tensors.Slice(input, Enumerable.Repeat(0, shape.Length).ToArray(), ends, shape.Length) : input; + return isPadded ? IR.F.Tensors.Slice(input, Enumerable.Repeat(0L, shape.Length).ToArray(), ends, shape.Length) : input; } /// @@ -54,11 +54,11 @@ public static Expr SliceForPack(Expr input, int[] shape, int[] padNums) /// new shape. /// mat. /// bool. 
- public static bool TryGetShapeMapMatrix(int[] inShape, int[] newShape, out int[,] mat) + public static bool TryGetShapeMapMatrix(long[] inShape, long[] newShape, out int[,] mat) { - int ProdIn(int[,] cmat, int i) + long ProdIn(int[,] cmat, int i) { - var prod = 1; + long prod = 1; for (int j = 0; j < inShape.Length; j++) { var v = cmat[i, j] * inShape[j]; @@ -71,9 +71,9 @@ int ProdIn(int[,] cmat, int i) return prod; } - int ProdOut(int[,] cmat, int j) + long ProdOut(int[,] cmat, int j) { - var prod = 1; + long prod = 1; for (int i = 0; i < newShape.Length; i++) { var v = cmat[i, j] * newShape[i]; diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs index 999ca8755..657b771a5 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs @@ -406,15 +406,15 @@ private TextSnippet Visit(TensorConst expr, Tensor tensor) { if (!_context.ConstSymbols.TryGetValue(expr, out var buffer)) { - buffer = WriteRdata(tensor.BytesBuffer, _alignment); + buffer = WriteRdata(tensor, _alignment); _context.ConstSymbols.Add(expr, buffer); } // stack: dtype shape strides buffer var snippet = BeginTextSnippet(expr); LeaGp(_rdataGpid, buffer); - LdStrides(tensor.Strides); - LdShape(tensor.Dimensions); + LdStrides(tensor.Strides.ToInts()); + LdShape(tensor.Dimensions.ToInts()); LdDataType(tensor.ElementType); Emitter.LdTensor(); return snippet; diff --git a/src/Native/src/kernels/stackvm/optimized/matmul.cpp b/src/Native/src/kernels/stackvm/optimized/matmul.cpp index bb0d1befd..79529c288 100644 --- a/src/Native/src/kernels/stackvm/optimized/matmul.cpp +++ b/src/Native/src/kernels/stackvm/optimized/matmul.cpp @@ -24,10 +24,10 @@ using namespace nncase::kernels; using namespace nncase::kernels::stackvm; using namespace nncase::kernels::stackvm::optimized; -result optimized::matmul(typecode_t typecode, const gsl::byte *input_a, const gsl::byte *input_b, - gsl::byte *output, - gsl::span in_a_shape, - gsl::span in_b_shape, +result optimized::matmul(typecode_t typecode, const std::byte *input_a, const std::byte *input_b, + std::byte *output, + std::span in_a_shape, + std::span in_b_shape, [[maybe_unused]] kernel_context &context) noexcept { return stackvm::reference::matmul(typecode, input_a, input_b, output, in_a_shape, in_b_shape, context); } diff --git a/src/Native/src/kernels/stackvm/optimized/opt_ops.h b/src/Native/src/kernels/stackvm/optimized/opt_ops.h index d27668776..0d1789167 100644 --- a/src/Native/src/kernels/stackvm/optimized/opt_ops.h +++ b/src/Native/src/kernels/stackvm/optimized/opt_ops.h @@ -140,9 +140,9 @@ unary(typecode_t dtype, runtime::stackvm::unary_op_t op, const std::byte *in, kernel_context &context = default_kernel_context()) noexcept; NNCASE_API result -matmul(typecode_t typecode, const gsl::byte *input_a, const gsl::byte *input_b, - gsl::byte *output, gsl::span in_a_shape, - gsl::span in_b_shape, +matmul(typecode_t typecode, const std::byte *input_a, const std::byte *input_b, + std::byte *output, std::span in_a_shape, + std::span in_b_shape, [[maybe_unused]] kernel_context &context) noexcept; // template diff --git a/src/Native/src/kernels/stackvm/reference/layer_norm.cpp b/src/Native/src/kernels/stackvm/reference/layer_norm.cpp index 4c822031e..4ab65b3e1 100644 --- a/src/Native/src/kernels/stackvm/reference/layer_norm.cpp +++ b/src/Native/src/kernels/stackvm/reference/layer_norm.cpp @@ -30,10 +30,6 @@ static void 
layernorm_impl(int inner_size, const T *src, const T *scale, for (auto i = 0; i < inner_size; i++) mean1 += src[i] / inner_size; } - if (use_mean) { - for (auto i = 0; i < inner_size; i++) - mean1 += src[i] / inner_size; - } std::vector sub(inner_size, 0); for (auto i = 0; i < inner_size; i++) @@ -87,7 +83,6 @@ result layer_norm_impl2(const T *input, T *output, const T *scale, return layer_norm_impl2(IN_CAST(type, input), OUT_CAST(type, output), \ IN_CAST(type, scale), IN_CAST(type, bias), \ in_shape, axis, epsilon, use_mean) - in_shape, axis, epsilon, use_mean) #define TYPE_SELECT_LAYER_NORM(_typecode, _impl) \ switch (_typecode) { \ diff --git a/src/Nncase.Cli/Program.cs b/src/Nncase.Cli/Program.cs index c19a0ae6b..1659605cc 100644 --- a/src/Nncase.Cli/Program.cs +++ b/src/Nncase.Cli/Program.cs @@ -127,7 +127,7 @@ private static CompileOptions ParseCompileOptions(System.CommandLine.Invocation. }, }; -#if true +#if false compileOptions.ShapeBucketOptions.Enable = true; compileOptions.ShapeBucketOptions.RangeInfo = new() { { "history_len", (0, 64) }, { "seq_len", (1, 64) } }; compileOptions.ShapeBucketOptions.SegmentsCount = 2; diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs index c0f95db26..ea9d2e38a 100644 --- a/src/Nncase.Core/CompilerServices.cs +++ b/src/Nncase.Core/CompilerServices.cs @@ -224,6 +224,8 @@ public interface ICompilerServicesProvider IEGraph ERewrite(IEGraph expr, IEnumerable rules, RunPassContext options); MicroKernelInfo GetOpMicroKernelInfo(Op op, MicroKernelContext context); + + Expr SimplifyForDimension(Expr value); } internal interface ICompilerServicesProviderInternal @@ -528,6 +530,8 @@ public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) => /// Target. public static ITarget GetTarget(string name) => Provider.GetTarget(name); + public static Expr SimplifyForDimension(Expr value) => Provider.SimplifyForDimension(value); + internal static DryIoc.IContainer CreateScope() { var container = (DryIoc.IContainer)_serviceProvider!; @@ -554,6 +558,7 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe private readonly IMetricEvaluateProvider _metricEvaluateProvider; private readonly IMatchProvider _matchProvider; private readonly IRewriteProvider _rewriteProvider; + private readonly ISimplifyProvider _simplifyProvider; private readonly IEGraphMatchProvider _eGraphMatchProvider; private readonly IEGraphRewriteProvider _eGraphrewriteProvider; private readonly ITargetProvider _targetProvider; @@ -569,6 +574,7 @@ public CompilerServicesProvider( IDataTypeServiceProvider dataTypeServiceProvider, IMatchProvider matchProvider, IRewriteProvider rewriteProvider, + ISimplifyProvider simplifyProvider, IEGraphMatchProvider eGraphMatchProvider, IEGraphRewriteProvider eGraphrewriteProvider, ITargetProvider targetProvider, @@ -584,6 +590,7 @@ public CompilerServicesProvider( DataTypeService = dataTypeServiceProvider; _matchProvider = matchProvider; _rewriteProvider = rewriteProvider; + _simplifyProvider = simplifyProvider; _eGraphMatchProvider = eGraphMatchProvider; _eGraphrewriteProvider = eGraphrewriteProvider; _targetProvider = targetProvider; @@ -719,4 +726,6 @@ public IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassC } public MicroKernelInfo GetOpMicroKernelInfo(Op op, MicroKernelContext context) => _microKernelInfoGetter.GetInfo(op, context); + + public Expr SimplifyForDimension(Expr value) => _simplifyProvider.SimplifyForDimension(value); } diff --git a/src/Nncase.Core/DataType.cs 
b/src/Nncase.Core/DataType.cs index 17c3783ce..dfd42a0e3 100644 --- a/src/Nncase.Core/DataType.cs +++ b/src/Nncase.Core/DataType.cs @@ -254,5 +254,5 @@ public VectorType(DataType elemType, params int[] lanes) _ => throw new NotSupportedException(), }; - public override int SizeInBytes => ElemType.SizeInBytes * (int)TensorUtilities.GetProduct(Lanes.ToArray()); + public override int SizeInBytes => ElemType.SizeInBytes * (int)TensorUtilities.GetProduct(TensorUtilities.ToLongs(Lanes.ToArray())); } diff --git a/src/Nncase.Core/Evaluator/ITypeInferenceContext.cs b/src/Nncase.Core/Evaluator/ITypeInferenceContext.cs index 9078b683c..2be88b3cf 100644 --- a/src/Nncase.Core/Evaluator/ITypeInferenceContext.cs +++ b/src/Nncase.Core/Evaluator/ITypeInferenceContext.cs @@ -23,6 +23,8 @@ public interface ITypeInferenceContext /// The argument expression. Expr GetArgument(Op op, ParameterInfo parameter); + Expr GetDimensionArgument(Op op, ParameterInfo parameter) => CompilerServices.SimplifyForDimension(GetArgument(op, parameter)); + /// /// Get arguments expression. /// diff --git a/src/Nncase.Core/Evaluator/Metric.cs b/src/Nncase.Core/Evaluator/Metric.cs index 958bcb58b..052ea5443 100644 --- a/src/Nncase.Core/Evaluator/Metric.cs +++ b/src/Nncase.Core/Evaluator/Metric.cs @@ -185,7 +185,7 @@ public static class MetricUtility public static UInt128 ResizeCubicFLOPs => 8; - public static UInt128 GetFLOPs(IRType type, int scale = 1) + public static UInt128 GetFLOPs(IRType type, long scale = 1) { return type switch { diff --git a/src/Nncase.Core/IR/Const.cs b/src/Nncase.Core/IR/Const.cs index d22090e09..fa761ab30 100644 --- a/src/Nncase.Core/IR/Const.cs +++ b/src/Nncase.Core/IR/Const.cs @@ -139,7 +139,7 @@ public static TensorConst FromTensor(Tensor tensor) /// /// convert shape to const expr. /// - public static Const FromShape(Shape shape) => FromTensor(Tensor.From(shape.ToValueArray())); + public static Const FromShape(Shape shape) => FromTensor(Tensor.From(shape.ToValueArray())); /// /// Convert value to const expr. diff --git a/src/Nncase.Core/IR/Dimension.cs b/src/Nncase.Core/IR/Dimension.cs index 83b1395ec..08f53edb7 100644 --- a/src/Nncase.Core/IR/Dimension.cs +++ b/src/Nncase.Core/IR/Dimension.cs @@ -6,6 +6,8 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Nncase.Passes; +using Nncase.Passes.Mutators; namespace Nncase.IR { @@ -23,31 +25,52 @@ public enum DimensionKind : byte /// Fixed dimesnion. /// Fixed, + + /// + /// Used for shape pattern. + /// + Any, } /// /// Shape dimension. /// - public struct Dimension : IEquatable + public sealed class Dimension : IEquatable { - /// - /// An unknown dimension. - /// - public static readonly Dimension Unknown; + public static readonly Dimension Any = new Dimension(); + + private readonly long? _fixedValue; + private readonly Expr? _exprValue; /// - /// Initializes a new instance of the struct. + /// Initializes a new instance of the class. /// /// Dimension value. 
- public Dimension(int value) + public Dimension(long value) { - if (value < 0) + Kind = DimensionKind.Fixed; + _fixedValue = value; + } + + public Dimension(Expr value) + { + value = CompilerServices.SimplifyForDimension(value); + if (value is TensorConst tc) { - throw new ArgumentOutOfRangeException(nameof(value), "Dimension should not be lower than 0."); + Kind = DimensionKind.Fixed; + _fixedValue = tc.Value.ToScalar(); } + else + { + Kind = DimensionKind.Unknown; + _exprValue = value; + } + } - Kind = DimensionKind.Fixed; - Value = value; + private Dimension() + { + Kind = DimensionKind.Any; + _exprValue = new Var("Any", DataTypes.Int64); } /// @@ -58,32 +81,40 @@ public Dimension(int value) /// /// Gets value. /// - public int? Value { get; } + public Expr Value => _exprValue ?? _fixedValue!.Value; /// /// Gets FixedValue. /// - public int FixedValue + public long FixedValue { - get => Value ?? + get => _fixedValue ?? throw new InvalidOperationException("Only Can Get It When Shape Is Fixed !"); } /// /// Gets a value indicating whether unknown. /// - public bool IsUnknown => Kind == DimensionKind.Unknown; + public bool IsUnknown => Kind is DimensionKind.Unknown or DimensionKind.Any; /// /// Gets a value indicating whether fixed. /// public bool IsFixed => Kind == DimensionKind.Fixed; + public bool IsAny => Kind == DimensionKind.Any; + /// /// Convert to a fixed . /// /// Dimension value. - public static implicit operator Dimension(int value) => new(value); + public static implicit operator Dimension(long value) => new(value); + + /// + /// Convert to a expression. + /// + /// Dimension value. + public static implicit operator Dimension(Expr value) => new(value); public static bool operator ==(Dimension left, Dimension right) { @@ -98,29 +129,63 @@ public int FixedValue public static Dimension operator +(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch { (true, true) => lhs.FixedValue + rhs.FixedValue, - (_, _) => Dimension.Unknown, + (_, _) => new Dimension(lhs.Value + rhs.Value), }; - public static Dimension operator +(Dimension lhs, int rhs) => lhs.IsFixed ? lhs.FixedValue + rhs : Unknown; + public static Dimension operator +(Dimension lhs, int rhs) => lhs.IsFixed ? 
lhs.FixedValue + rhs : new Dimension(lhs.Value + rhs); public static Dimension operator -(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch { (true, true) => lhs.FixedValue - rhs.FixedValue, - (_, _) => Dimension.Unknown, + (_, _) => new Dimension(lhs.Value - rhs.Value), }; public static Dimension operator *(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch { (true, true) => lhs.FixedValue * rhs.FixedValue, - (_, _) => Dimension.Unknown, + (_, _) => new Dimension(lhs.Value * rhs.Value), }; public static Dimension operator /(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch { (true, true) => lhs.FixedValue / rhs.FixedValue, - (_, _) => Dimension.Unknown, + (_, _) => new Dimension(lhs.Value / rhs.Value), }; + public static Dimension Abs(Dimension value) + { + if (value.IsFixed) + { + return System.Math.Abs(value.FixedValue); + } + + return IR.F.Math.Abs(value.Value); + } + + public static Dimension Clamp(Dimension value, Dimension min, Dimension max) + { + if (value.IsFixed && min.IsFixed && max.IsFixed) + { + return System.Math.Clamp(value.FixedValue, min.FixedValue, max.FixedValue); + } + + return IR.F.Math.Clamp(value.Value, min.Value, max.Value); + } + + public static Dimension CeilDiv(Dimension lhs, Dimension rhs) + { + if (lhs.IsFixed && rhs.IsFixed) + { + return (lhs.FixedValue + rhs.FixedValue - 1) / rhs.FixedValue; + } + + return IR.F.Math.CeilDiv(lhs.Value, rhs.Value); + } + + // public static Dimension Unknown(string? name = null) => new Dimension(name is null ? new Var(DataTypes.Int64) : new Var(name, DataTypes.Int64)); + + public static Dimension Unknown(string? name = null) => Any; + /// public override string ToString() { @@ -134,21 +199,26 @@ public override bool Equals(object? obj) } /// - public bool Equals(Dimension other) + public bool Equals(Dimension? other) { - return Kind == other.Kind && - Value == other.Value; + return other is not null && (Kind, other.Kind) switch + { + (DimensionKind.Any, DimensionKind.Any) => true, + (DimensionKind.Unknown, DimensionKind.Unknown) => Value == other.Value, + (DimensionKind.Fixed, DimensionKind.Fixed) => FixedValue == other.FixedValue, + (_, _) => false, + }; } /// public override int GetHashCode() { - return HashCode.Combine(Kind, Value); + return IsFixed ? HashCode.Combine(Kind, FixedValue) : HashCode.Combine(Kind, Value); } - public bool HasFixedValue(Predicate predicate) + public bool HasFixedValue(Predicate predicate) { - return Value.HasValue && predicate(Value.Value); + return IsFixed && predicate(FixedValue); } public bool IsAssignableFrom(Dimension dimension) diff --git a/src/Nncase.Core/IR/Expr.cs b/src/Nncase.Core/IR/Expr.cs index 179b500f7..760f39709 100644 --- a/src/Nncase.Core/IR/Expr.cs +++ b/src/Nncase.Core/IR/Expr.cs @@ -8,6 +8,7 @@ using System.Diagnostics; using System.Linq; using System.Reactive; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using CommunityToolkit.HighPerformance.Helpers; @@ -29,13 +30,12 @@ public class IRMetadata /// /// Expression. /// -public abstract partial class Expr : IDisposable +public abstract partial class Expr { private readonly Expr[] _operands; - private readonly ConcurrentDictionary _users = new(ReferenceEqualityComparer.Instance); + private readonly ConditionalWeakTable _users = new(); private IRType? _checkedType; private int? _hashCodeCache; - private bool _disposedValue; internal Expr(IEnumerable operands) { @@ -160,17 +160,17 @@ public DataType CheckedDataType /// /// Gets users. 
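With Dimension refactored above into a class backed by either a fixed long or a symbolic Expr, the arithmetic operators keep folding when both sides are fixed and otherwise build a symbolic dimension instead of collapsing to Unknown. A rough sketch of the intended behaviour, assuming the conversions shown in this patch (variable names are illustrative):

    Dimension a = 8;                                    // fixed, via the long conversion
    Dimension b = new Var("seq_len", DataTypes.Int64);  // symbolic, via the Expr constructor
    var c = a * 2;                       // still fixed: 16
    var d = a + b;                       // symbolic: wraps the expression 8 + seq_len
    var e = Dimension.CeilDiv(d, 4);     // stays symbolic until all inputs are fixed
    Console.WriteLine(d.IsFixed);        // False
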
/// - public IEnumerable Users => EnsureAlive()._users.Keys; + public IEnumerable Users => _users.Select(x => x.Key).ToArray(); /// /// Gets operands. /// - public ReadOnlySpan Operands => EnsureAlive()._operands; + public ReadOnlySpan Operands => _operands; /// /// Gets a value indicating whether the expr is alive. /// - public bool IsAlive => !_disposedValue; + public bool IsAlive => Users.Any(); /// /// Gets or sets raw checked type. @@ -219,30 +219,15 @@ public override bool Equals(object? obj) /// public sealed override int GetHashCode() => _hashCodeCache ??= GetHashCodeCore(); - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - public void DisposeIfNoUsers() - { - if (_users.Keys.Count == 0) - { - Dispose(); - } - } - internal void AddUser(Expr user) { - EnsureAlive(); Trace.Assert(!ReferenceEquals(this, user)); - _users.TryAdd(user.EnsureAlive(), default); + _users.TryAdd(user, default); } internal void RemoveUser(Expr user) { - _users.Remove(user, out _); + _users.Remove(user); } internal void ReplaceOperand(int index, Expr newOperand) @@ -262,7 +247,6 @@ internal void ReplaceAllUsesWith(Expr newOperand) internal void ReplaceScopedUsesWith(Expr newOperand, IReadOnlySet? scope) { - EnsureAlive(); if (!ReferenceEquals(this, newOperand)) { foreach (var user in Users.ToArray()) @@ -293,20 +277,6 @@ protected virtual int GetHashCodeCore() return HashCode.Combine(GetType(), HashCode.Combine(Operands)); } - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - foreach (var operand in _operands) - { - operand.RemoveUser(this); - operand.DisposeIfNoUsers(); - } - - _disposedValue = true; - } - } - private bool IsDescendantOf(Expr other, Dictionary visited) { if (visited.TryGetValue(this, out var result)) @@ -380,14 +350,4 @@ private void InvalidateUsersHashCodeCache() user.InvalidateHashCodeCache(); } } - - private Expr EnsureAlive() - { - if (_disposedValue) - { - throw new ObjectDisposedException(null); - } - - return this; - } } diff --git a/src/Nncase.Core/IR/ExprRewriter.cs b/src/Nncase.Core/IR/ExprRewriter.cs index e09ac4b7e..ebef1b92e 100644 --- a/src/Nncase.Core/IR/ExprRewriter.cs +++ b/src/Nncase.Core/IR/ExprRewriter.cs @@ -70,20 +70,8 @@ protected override void VisitOperands(Expr expr, TContext context) private void DCE(Expr root, ExprScope exprScope) { - using var exprPin = new ExprPinner(root); - foreach (var expr in ExprMemo) - { - expr.Key.DisposeIfNoUsers(); - expr.Value.DisposeIfNoUsers(); - } - - foreach (var expr in exprScope.Exprs) - { - if (expr is not ExprUser) - { - expr.DisposeIfNoUsers(); - } - } + // using var exprPin = new ExprPinner(root); + // GC.Collect(); } } diff --git a/src/Nncase.Core/IR/IRHelpers.cs b/src/Nncase.Core/IR/IRHelpers.cs index a2b776ff8..f57a5df13 100644 --- a/src/Nncase.Core/IR/IRHelpers.cs +++ b/src/Nncase.Core/IR/IRHelpers.cs @@ -18,6 +18,8 @@ public static class IRHelpers public static void DCE(BaseFunction function) { + GC.Collect(); + return; using var exprPin = new ExprPinner(function); var exprs = ExprCollector.Collect(function); var users = new HashSet(ReferenceEqualityComparer.Instance); @@ -40,12 +42,7 @@ void AddUsers(Expr expr) AddUsers(expr); } - foreach (var user in users) - { - user.DisposeIfNoUsers(); - } - - DCESanity(function); + GC.Collect(); } [Conditional("DEBUG")] diff --git a/src/Nncase.Core/IR/IRModule.cs b/src/Nncase.Core/IR/IRModule.cs index b694ac5f4..6df5676ea 100644 --- a/src/Nncase.Core/IR/IRModule.cs +++ b/src/Nncase.Core/IR/IRModule.cs @@ -78,7 
+78,7 @@ public void Replace(int index, BaseFunction function) if (old.IsAlive) { old.ReplaceAllUsesWith(function); - old.DisposeIfNoUsers(); + GC.Collect(); } old = function; @@ -97,11 +97,7 @@ public void Remove(BaseFunction function) } function.RemoveUser(_exprUser); - if (function.IsAlive) - { - function.DisposeIfNoUsers(); - } - _functions.RemoveAt(index); + GC.Collect(); } } diff --git a/src/Nncase.Core/IR/Math/Functional.cs b/src/Nncase.Core/IR/Math/Functional.cs index 96bb9b4cd..de4e18d36 100644 --- a/src/Nncase.Core/IR/Math/Functional.cs +++ b/src/Nncase.Core/IR/Math/Functional.cs @@ -241,6 +241,13 @@ public static Call Clamp(Expr input, ValueRange range) /// Result expression. public static Call Max(Expr lhs, Expr rhs) => Binary(BinaryOp.Max, lhs, rhs); + /// + /// Call max. + /// + /// value operands. + /// Result expression. + public static Expr Max(IEnumerable values) => values.Aggregate(Max); + /// /// Call pow. /// diff --git a/src/Nncase.Core/IR/NN/SpaceToBatch.cs b/src/Nncase.Core/IR/NN/SpaceToBatch.cs index 17aa97544..8fe5fd82b 100644 --- a/src/Nncase.Core/IR/NN/SpaceToBatch.cs +++ b/src/Nncase.Core/IR/NN/SpaceToBatch.cs @@ -31,5 +31,5 @@ public sealed partial class SpaceToBatch : Op /// /// Gets paddings. /// - public static readonly ParameterInfo Paddings = new(typeof(SpaceToBatch), 2, "paddings", HasShape(new[] { Dimension.Unknown, 2 }) & IsIntegral()); + public static readonly ParameterInfo Paddings = new(typeof(SpaceToBatch), 2, "paddings", HasShape(new[] { Dimension.Any, 2 }) & IsIntegral()); } diff --git a/src/Nncase.Core/IR/Shape.cs b/src/Nncase.Core/IR/Shape.cs index 3e8e63fbd..e9ab39f37 100644 --- a/src/Nncase.Core/IR/Shape.cs +++ b/src/Nncase.Core/IR/Shape.cs @@ -9,6 +9,7 @@ using System.Text; using System.Threading.Tasks; using NetFabric.Hyperlinq; +using Nncase.IR.Tensors; namespace Nncase.IR { @@ -114,6 +115,24 @@ public Shape(IEnumerable dimensions) { } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(IEnumerable dimensions) + : this(dimensions.Select(i => (Dimension)i).ToArray()) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(ReadOnlySpan dimensions) + : this(dimensions.AsValueEnumerable().Select(i => (Dimension)i).ToArray()) + { + } + private Shape(ShapeKind kind, IEnumerable dimensions) { Kind = kind; @@ -184,7 +203,7 @@ private Shape(ShapeKind kind, IEnumerable dimensions) /// /// Gets get Total Elements. /// - public int Size => Enumerable.Range(0, Rank).Aggregate(1, (size, i) => size * _dimensions[i].FixedValue); + public long Size => Enumerable.Range(0, Rank).Aggregate(1L, (size, i) => size * _dimensions[i].FixedValue); /// public int Count => ((IReadOnlyCollection)_dimensions).Count; @@ -195,12 +214,14 @@ private Shape(ShapeKind kind, IEnumerable dimensions) ? ((IReadOnlyList)_dimensions)[index] : ((IReadOnlyList)_dimensions)[Rank + index]; - public static implicit operator ReadOnlySpan(Shape shape) => shape._dimensions.Select(x => (int)(x.Value ?? 
-1)).ToArray(); + public static implicit operator ReadOnlySpan(Shape shape) => shape._dimensions.Select(x => x.FixedValue).ToArray(); public static implicit operator Shape(Dimension[] dimensions) => new Shape(dimensions); public static implicit operator Shape(int[] dimensions) => new Shape(dimensions); + public static implicit operator Shape(long[] dimensions) => new Shape(dimensions); + public static bool operator ==(Shape lhs, Shape rhs) { return lhs.Equals(rhs); @@ -216,7 +237,36 @@ private Shape(ShapeKind kind, IEnumerable dimensions) /// public static Shape Unknown(int rank) { - return new Shape(ShapeKind.HasUnknownDimension, Enumerable.Repeat(Dimension.Unknown, rank)); + return new Shape(ShapeKind.HasUnknownDimension, Enumerable.Range(0, rank).Select(x => Dimension.Unknown())); + } + + /// + /// Gets a shape with rank unknwon dimension. + /// + public static Shape FromExpr(Expr value) + { + if (value is TensorConst tc) + { + return new Shape(tc.Value.ToArray()); + } + else if (value is Call { Target: Concat } concat) + { + if (concat.Arguments[Concat.Input.Index] is Tuple tuple) + { + return new Shape(tuple.Fields); + } + } + + var shape = value.CheckedShape; + if (shape.Rank != 1 || !shape.IsFixed) + { + // throw new ArgumentException($"Invalid shape expr: {value}", nameof(value)); + return Shape.Unranked; + } + + var rank = (int)shape[0].FixedValue; + // return new Shape(Enumerable.Range(0, rank).Select(x => (Dimension)value[x])); + return Shape.Unknown(rank); } /// @@ -254,7 +304,7 @@ public Shape InsertAndClone(int index, IEnumerable dims) /// /// convert to the int list. /// - public List ToValueList() + public List ToValueList() { return _dimensions.Select(dim => dim.FixedValue).ToList(); } @@ -262,7 +312,7 @@ public List ToValueList() /// /// convert the int array. /// - public int[] ToValueArray() + public long[] ToValueArray() { return _dimensions.Select(dim => dim.FixedValue).ToArray(); } diff --git a/src/Nncase.Core/IR/TypePattern.cs b/src/Nncase.Core/IR/TypePattern.cs index f827cce9f..54bfe3fb7 100644 --- a/src/Nncase.Core/IR/TypePattern.cs +++ b/src/Nncase.Core/IR/TypePattern.cs @@ -128,7 +128,7 @@ public static TypePattern HasShape(Shape target_shape) => HasShape( inshape => inshape.Rank == target_shape.Rank && inshape.Zip(target_shape).All( - (dim) => dim.Second == Dimension.Unknown ? true : dim.Second == dim.First), + (dim) => dim.Second.IsAny ? 
true : dim.Second == dim.First), $"Shape = {target_shape}"); /// diff --git a/src/Nncase.Core/Passes/GetReplaceUtility.cs b/src/Nncase.Core/Passes/GetReplaceUtility.cs index fcca1e4fb..958060cd5 100644 --- a/src/Nncase.Core/Passes/GetReplaceUtility.cs +++ b/src/Nncase.Core/Passes/GetReplaceUtility.cs @@ -61,11 +61,11 @@ Fx WithTmpTypeImpl(Fx inputCtor) => input => return Apply(WithTmpTypeImpl, inputCtor); } - public static Fx WithTmp4DShape(Fx inputCtor, int[] originOutShape) + public static Fx WithTmp4DShape(Fx inputCtor, long[] originOutShape) { Fx WithTmpGNNEShape(Fx inCtor) => input => - ((Func)(shape => + ((Func)(shape => Reshape( inCtor(Reshape(input, Get4DGNNEShape(shape))), originOutShape)))(input.CheckedShape.ToValueArray()); @@ -73,13 +73,13 @@ Fx WithTmpGNNEShape(Fx inCtor) => return Apply(WithTmpGNNEShape, inputCtor); } - internal static int[] Get4DGNNEShape(int[] dims) + internal static int[] Get4DGNNEShape(long[] dims) { if (dims.Length > 4) { throw new InvalidOperationException("dims Length should <= 4"); } - return Enumerable.Repeat(1, 4 - dims.Length).Concat(dims).ToArray(); + return Enumerable.Repeat(1, 4 - dims.Length).Concat(dims.Select(x => checked((int)x))).ToArray(); } } diff --git a/src/Nncase.Core/Passes/IRewriteProvider.cs b/src/Nncase.Core/Passes/IRewriteProvider.cs index e38cd2a15..a6ebc9146 100644 --- a/src/Nncase.Core/Passes/IRewriteProvider.cs +++ b/src/Nncase.Core/Passes/IRewriteProvider.cs @@ -49,3 +49,8 @@ public interface IEGraphRewriteProvider /// Rewrited EGraph. IEGraph ERewrite(IEGraph eGraph, IEnumerable rules, RunPassContext options); } + +public interface ISimplifyProvider +{ + Expr SimplifyForDimension(Expr expr); +} diff --git a/src/Nncase.Core/Schedule/MicroKernelInfo.cs b/src/Nncase.Core/Schedule/MicroKernelInfo.cs index e9dc4f7e1..3eb464637 100644 --- a/src/Nncase.Core/Schedule/MicroKernelInfo.cs +++ b/src/Nncase.Core/Schedule/MicroKernelInfo.cs @@ -30,10 +30,10 @@ public enum BufferState : byte /// /// micro kernel infomation for auto tiling. /// -public record MicroKernelInfo(int[] Primitives, ValueRange[] Multipliers, MicroKernelBufferInfo[] BufferInfos, Func GetComputeCycle) +public record MicroKernelInfo(int[] Primitives, ValueRange[] Multipliers, MicroKernelBufferInfo[] BufferInfos, Func GetComputeCycle) { } -public record MicroKernelContext(Op Op, ImmutableArray AccessMaps, ImmutableArray> BufferShapes, ITargetOptions TargetOptions) +public record MicroKernelContext(Op Op, ImmutableArray AccessMaps, ImmutableArray> BufferShapes, ITargetOptions TargetOptions) { } diff --git a/src/Nncase.Core/Tensor.cs b/src/Nncase.Core/Tensor.cs index 2314ea9dc..103706790 100644 --- a/src/Nncase.Core/Tensor.cs +++ b/src/Nncase.Core/Tensor.cs @@ -65,15 +65,15 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab private static readonly MethodInfo _tensorCastFunc = typeof(Tensor).GetMethod(nameof(Cast))!; - private readonly int[] _dimensions; - private readonly int[] _strides; + private readonly long[] _dimensions; + private readonly long[] _strides; /// /// Initializes a new instance of the class. /// /// Element type. /// Size of the 1-dimensional tensor. - internal Tensor(DataType elementType, int length) + internal Tensor(DataType elementType, long length) { ElementType = elementType; Shape = new Shape(length); @@ -87,12 +87,12 @@ internal Tensor(DataType elementType, int length) /// /// Element type. /// An span of integers that represent the size of each dimension of the DenseTensor to create. 
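The Tensor changes that follow switch dimensions, strides and Length from int to long throughout, so element counts beyond int.MaxValue stay representable while buffer access narrows with checked casts. A hedged usage sketch under the signatures in this patch:

    long[] dims = { 2, 3, 4 };
    var t = Tensor.FromScalar(1.0f, dims);   // dimensions now flow through as ReadOnlySpan<long>
    long linear = TensorUtilities.GetIndex(t.Strides, new long[] { 1, 2, 3 });
    Console.WriteLine(t.Length);             // 24, now a long
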
- internal Tensor(DataType elementType, ReadOnlySpan dimensions) + internal Tensor(DataType elementType, ReadOnlySpan dimensions) { ElementType = elementType; _dimensions = dimensions.ToArray(); Shape = dimensions.IsEmpty ? Shape.Scalar : new Shape(_dimensions); - Length = (int)TensorUtilities.GetProduct(dimensions); + Length = TensorUtilities.GetProduct(dimensions); _strides = TensorUtilities.GetStrides(dimensions); } @@ -104,12 +104,12 @@ internal Tensor(DataType elementType, ReadOnlySpan dimensions) /// /// Gets dimensions. /// - public ReadOnlySpan Dimensions => _dimensions; + public ReadOnlySpan Dimensions => _dimensions; /// /// Gets strides. /// - public ReadOnlySpan Strides => _strides; + public ReadOnlySpan Strides => _strides; /// /// Gets shape. @@ -124,14 +124,14 @@ internal Tensor(DataType elementType, ReadOnlySpan dimensions) /// /// Gets total length. /// - public int Length { get; } + public long Length { get; } /// /// Gets bytes buffer. /// public abstract Span BytesBuffer { get; } - int ICollection.Count => Length; + int ICollection.Count => checked((int)Length); bool ICollection.IsSynchronized => false; @@ -153,7 +153,7 @@ internal Tensor(DataType elementType, ReadOnlySpan dimensions) /// A one-dimensional array of integers that represent the indices specifying the /// position of the element to get. /// The value at the specified position in this Tensor. - public object this[ReadOnlySpan indices] + public object this[ReadOnlySpan indices] { get => GetValueCore(TensorUtilities.GetIndex(Strides, indices)); set => SetValueCore(TensorUtilities.GetIndex(Strides, indices), value); @@ -165,7 +165,7 @@ public object this[ReadOnlySpan indices] /// A one-dimensional array of integers that represent the indices specifying the /// position of the element to get. /// The value at the specified position in this Tensor. - public object this[params int[] indices] + public object this[params long[] indices] { get => this[indices.AsSpan()]; set => this[indices.AsSpan()] = value; @@ -180,7 +180,7 @@ public object this[params int[] indices] public static Tensor FromScalar(T value) where T : unmanaged, IEquatable { - var tensor = new Tensor(ReadOnlySpan.Empty); + var tensor = new Tensor(ReadOnlySpan.Empty); tensor[0] = value; return tensor; } @@ -192,7 +192,7 @@ public static Tensor FromScalar(T value) /// Value. /// Fill length. /// Created tensor. - public static Tensor FromScalar(T value, int length) + public static Tensor FromScalar(T value, long length) where T : unmanaged, IEquatable { var tensor = new Tensor(MemoryMarshal.CreateReadOnlySpan(ref length, 1)); @@ -207,7 +207,7 @@ public static Tensor FromScalar(T value, int length) /// Value. /// Fill dimensions. /// Created tensor. - public static Tensor FromScalar(T value, ReadOnlySpan dimensions) + public static Tensor FromScalar(T value, ReadOnlySpan dimensions) where T : unmanaged, IEquatable { var tensor = new Tensor(dimensions); @@ -221,9 +221,9 @@ public static Tensor FromScalar(T value, ReadOnlySpan dimensions) /// Start value. /// Count. /// Created tensor. 
- public static Tensor FromRange(int start, int count) + public static Tensor FromRange(long start, long count) { - var tensor = new Tensor(MemoryMarshal.CreateReadOnlySpan(ref count, 1)); + var tensor = new Tensor(MemoryMarshal.CreateReadOnlySpan(ref count, 1)); var buffer = tensor.Buffer.Span; for (int i = 0; i < count; i++) { @@ -233,7 +233,7 @@ public static Tensor FromRange(int start, int count) return tensor; } - public static Tensor From(DataType dataType, ITensorInitializer initializer, ReadOnlySpan dimensions) + public static Tensor From(DataType dataType, ITensorInitializer initializer, ReadOnlySpan dimensions) { var tensor = Zeros(dataType, dimensions); tensor.Initialize(initializer); @@ -249,7 +249,7 @@ public static Tensor From(DataType dataType, ITensorInitializer initializer, Rea public static Tensor From(Memory memory) where T : unmanaged, IEquatable { - var dim = memory.Length; + long dim = memory.Length; return new Tensor(memory, MemoryMarshal.CreateReadOnlySpan(ref dim, 1)); } @@ -260,7 +260,7 @@ public static Tensor From(Memory memory) /// Memory. /// Dimensions. /// Created tensor. - public static Tensor From(Memory memory, ReadOnlySpan dimensions) + public static Tensor From(Memory memory, ReadOnlySpan dimensions) where T : unmanaged, IEquatable { return new Tensor(memory, dimensions); @@ -285,7 +285,7 @@ public static Tensor From(T[] array) /// Array. /// Dimensions. /// Created tensor. - public static Tensor From(T[] array, ReadOnlySpan dimensions) + public static Tensor From(T[] array, ReadOnlySpan dimensions) where T : unmanaged, IEquatable { return From(array.AsMemory(), dimensions); @@ -298,7 +298,7 @@ public static Tensor From(T[] array, ReadOnlySpan dimensions) /// Bytes memory. /// Dimensions. /// Created tensor. - public static Tensor FromBytes(Memory memory, ReadOnlySpan dimensions) + public static Tensor FromBytes(Memory memory, ReadOnlySpan dimensions) where T : unmanaged, IEquatable { return new Tensor(memory.Cast(), dimensions); @@ -311,7 +311,7 @@ public static Tensor FromBytes(Memory memory, ReadOnlySpan dime /// Bytes memory. /// Dimensions. /// Created tensor. - public static Tensor FromBytes(DataType type, Memory memory, ReadOnlySpan dimensions) + public static Tensor FromBytes(DataType type, Memory memory, ReadOnlySpan dimensions) { return (Tensor)_tensorCreateFromBytesFunc.MakeGenericMethod(type.CLRType).Invoke(null, new object[] { memory, dimensions.ToArray() })!; } @@ -327,7 +327,7 @@ public static Tensor FromBytes(TensorType type, Memory buffer) return FromBytes(type.DType, buffer, type.Shape.ToValueArray()); } - public static Tensor FromStream(DataType type, Stream stream, ReadOnlySpan dimensions) + public static Tensor FromStream(DataType type, Stream stream, ReadOnlySpan dimensions) { var tensor = Tensor.Zeros(type, dimensions); tensor.Deserialize(stream); @@ -342,7 +342,7 @@ public static Tensor FromStream(DataType type, Stream stream, ReadOnlySpan public static unsafe Tensor FromArray(Array array) { var elemType = array.GetType().GetElementType()!; - var dims = new int[array.Rank]; + var dims = new long[array.Rank]; for (int i = 0; i < array.Rank; i++) { dims[i] = array.GetLength(i); @@ -403,14 +403,14 @@ public static Tensor FromConst(Const @const, CastMode castMode = CastMode. /// unmanaged type. /// dimensions. /// Tensor{T}. 
- public static Tensor Zeros(ReadOnlySpan dimensions) + public static Tensor Zeros(ReadOnlySpan dimensions) where T : unmanaged, IEquatable { var value = (T)Convert.ChangeType(0, typeof(T)); return Tensor.FromScalar(value, dimensions); } - public static Tensor Zeros(DataType dataType, ReadOnlySpan dimensions) + public static Tensor Zeros(DataType dataType, ReadOnlySpan dimensions) { return (Tensor)_tensorCreateEmptyFunc.MakeGenericMethod(dataType.CLRType).Invoke(null, new object[] { dimensions.ToArray() })!; } @@ -421,7 +421,7 @@ public static Tensor Zeros(DataType dataType, ReadOnlySpan dimensions) /// unmanaged type. /// dimensions. /// Tensor{T}. - public static Tensor Ones(ReadOnlySpan dimensions) + public static Tensor Ones(ReadOnlySpan dimensions) where T : unmanaged, IEquatable { var value = (T)Convert.ChangeType(1, typeof(T)); @@ -469,7 +469,7 @@ public IEnumerator GetEnumerator() public abstract void Serialize(Stream stream); - public abstract void Serialize(Stream baseStream, long offset, int[] shape, int[] strides); + public abstract void Serialize(Stream baseStream, long offset, long[] shape, long[] strides); int IStructuralComparable.CompareTo(object? other, IComparer comparer) { @@ -536,26 +536,26 @@ void IList.RemoveAt(int index) private protected abstract void CopyToCore(Array array, int index); - private protected abstract object GetValueCore(int index); + private protected abstract object GetValueCore(long index); - private protected abstract void SetValueCore(int index, object? value); + private protected abstract void SetValueCore(long index, object? value); private protected abstract void Initialize(ITensorInitializer initializer); - private static Tensor CreateTensorFromBytesImpl(Memory buffer, int[] dimensions) + private static Tensor CreateTensorFromBytesImpl(Memory buffer, long[] dimensions) where T : unmanaged, IEquatable { return new Tensor(buffer.Cast(), dimensions); } - private static Tensor CreateTensorFromArrayImpl(Array array, int[] dimensions) + private static Tensor CreateTensorFromArrayImpl(Array array, long[] dimensions) where T : unmanaged, IEquatable { var mmgr = new ArrayMemoryManager(array); return new Tensor(mmgr.Memory, dimensions); } - private static Tensor CreateTensorEmptyImpl(int[] dimensions) + private static Tensor CreateTensorEmptyImpl(long[] dimensions) where T : unmanaged, IEquatable { return new Tensor(dimensions); diff --git a/src/Nncase.Core/TensorOfT.cs b/src/Nncase.Core/TensorOfT.cs index 0197b014b..691947803 100644 --- a/src/Nncase.Core/TensorOfT.cs +++ b/src/Nncase.Core/TensorOfT.cs @@ -34,7 +34,7 @@ public unsafe sealed partial class Tensor : Tensor, IEnumerable, ICollecti /// Initializes a new instance of the class. /// /// Size of the 1-dimensional tensor. - public Tensor(int length) + public Tensor(long length) : base(DataType.FromType(), length) { Buffer = new T[length]; @@ -44,7 +44,7 @@ public Tensor(int length) /// Initializes a new instance of the class. /// /// An span of integers that represent the size of each dimension of the DenseTensor to create. - public Tensor(ReadOnlySpan dimensions) + public Tensor(ReadOnlySpan dimensions) : base(DataType.FromType(), dimensions) { Buffer = new T[Length]; @@ -55,7 +55,7 @@ public Tensor(ReadOnlySpan dimensions) /// /// Buffer memory. /// An span of integers that represent the size of each dimension of the DenseTensor to create. 
- public Tensor(Memory buffer, ReadOnlySpan dimensions) + public Tensor(Memory buffer, ReadOnlySpan dimensions) : base(DataType.FromType(), dimensions) { Trace.Assert(Length == buffer.Length); @@ -72,11 +72,11 @@ public Tensor(Memory buffer, ReadOnlySpan dimensions) /// public override Span BytesBuffer => MemoryMarshal.AsBytes(Buffer.Span); - int ICollection.Count => Length; + int ICollection.Count => checked((int)Length); bool ICollection.IsReadOnly => false; - int IReadOnlyCollection.Count => Length; + int IReadOnlyCollection.Count => checked((int)Length); T IReadOnlyList.this[int index] => GetValue(index); @@ -92,7 +92,7 @@ T IList.this[int index] /// A one-dimensional array of integers that represent the indices specifying the /// position of the element to get. /// The value at the specified position in this Tensor. - public new T this[ReadOnlySpan indices] + public new T this[ReadOnlySpan indices] { get => GetValue(TensorUtilities.GetIndex(Strides, indices)); set => SetValue(TensorUtilities.GetIndex(Strides, indices), value); @@ -104,7 +104,7 @@ T IList.this[int index] /// A one-dimensional array of integers that represent the indices specifying the /// position of the element to get. /// The value at the specified position in this Tensor. - public new T this[params int[] indices] + public new T this[params long[] indices] { get => this[indices.AsSpan()]; set => this[indices.AsSpan()] = value; @@ -134,7 +134,7 @@ public Tensor Clone() /// Type contained in the returned Tensor. /// An span of integers that represent the size of each dimension of the DenseTensor to create. /// A new tensor with the same layout as this tensor but different type and dimensions. - public Tensor CloneEmpty(ReadOnlySpan dimensions) + public Tensor CloneEmpty(ReadOnlySpan dimensions) where TResult : unmanaged, IEquatable { return new Tensor(dimensions); @@ -146,9 +146,9 @@ public Tensor CloneEmpty(ReadOnlySpan dimensions) /// /// An integer index computed as a dot-product of indices. /// The value at the specified position in this Tensor. - public T GetValue(int index) + public T GetValue(long index) { - return Buffer.Span[index]; + return Buffer.Span[checked((int)index)]; } /// @@ -156,7 +156,7 @@ public T GetValue(int index) /// /// An span of integers that represent the size of each dimension of the DenseTensor to create. /// A new tensor that reinterprets backing Buffer of this tensor with different dimensions. - public Tensor Reshape(ReadOnlySpan dimensions) + public Tensor Reshape(ReadOnlySpan dimensions) { if (Length != TensorUtilities.GetProduct(dimensions)) { @@ -172,9 +172,9 @@ public Tensor Reshape(ReadOnlySpan dimensions) /// /// An integer index computed as a dot-product of indices. /// The new value to set at the specified position in this Tensor. - public void SetValue(int index, T value) + public void SetValue(long index, T value) { - Buffer.Span[index] = value; + Buffer.Span[checked((int)index)] = value; } /// @@ -202,7 +202,7 @@ public bool Contains(T value) /// The object to locate in the . /// The index of item if found in the tensor. /// true if item is found in the ; otherwise, false. 
- public bool TryGetIndicesOf(T item, Span indices) + public bool TryGetIndicesOf(T item, Span indices) { if (indices.Length != Rank) { @@ -232,12 +232,12 @@ public override string GetArrayString(bool includeWhitespace = true) var builder = new StringBuilder(); - var indices = new int[Rank]; + var indices = new long[Rank]; var innerDimension = Rank - 1; var innerLength = Dimensions[innerDimension]; int indent = 0; - for (int outerIndex = 0; outerIndex < Length; outerIndex += innerLength) + for (long outerIndex = 0; outerIndex < Length; outerIndex += innerLength) { TensorUtilities.GetIndices(Strides, false, outerIndex, indices); @@ -337,10 +337,10 @@ public override void Serialize(Stream stream) SpanUtility.Serialize((ReadOnlySpan)Buffer.Span, stream); } - public override void Serialize(Stream baseStream, long offset, int[] shape, int[] strides) + public override void Serialize(Stream baseStream, long offset, long[] shape, long[] strides) { var slice = new T[TensorUtilities.GetSize(shape, strides, 1)]; - var index = new int[shape.Length]; + var index = new long[shape.Length]; void Copy(int axis) { @@ -354,8 +354,8 @@ void Copy(int axis) else { var length = shape.LastOrDefault(1); - var src = Buffer.Span.Slice((int)offset + TensorUtilities.GetIndex(Strides, index), length); - var dest = slice.AsSpan(TensorUtilities.GetIndex(strides, index), length); + var src = Buffer.Span.Slice(checked((int)(offset + TensorUtilities.GetIndex(Strides, index))), checked((int)length)); + var dest = slice.AsSpan(checked((int)TensorUtilities.GetIndex(strides, index)), checked((int)length)); src.CopyTo(dest); } } @@ -573,12 +573,12 @@ private protected override void CopyToCore(Array array, int index) } } - private protected override object GetValueCore(int index) + private protected override object GetValueCore(long index) { return GetValue(index); } - private protected override void SetValueCore(int index, object? value) + private protected override void SetValueCore(long index, object? value) { SetValue(index, (T)Convert.ChangeType(value, typeof(T))!); } @@ -642,7 +642,7 @@ private int CompareTo(Array other, IComparer comparer) } var bufferA = Buffer; - var indices = new int[Rank]; + var indices = new long[Rank]; int result = 0; for (int i = 0; i < bufferA.Length; i++) @@ -701,7 +701,7 @@ private bool Equals(Array other, IEqualityComparer comparer) } var bufferA = Buffer; - var indices = new int[Rank]; + var indices = new long[Rank]; for (int i = 0; i < bufferA.Length; i++) { diff --git a/src/Nncase.Core/TensorUtilities.cs b/src/Nncase.Core/TensorUtilities.cs index 323e34e99..ec01e0a3b 100644 --- a/src/Nncase.Core/TensorUtilities.cs +++ b/src/Nncase.Core/TensorUtilities.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Numerics; using System.Text; using System.Threading.Tasks; using Nncase.IR; @@ -30,17 +31,18 @@ private enum SliceStatus : uint /// /// get the product from the start index on the dimensions. 
/// - public static long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) + public static T GetProductGeneric(ReadOnlySpan dimensions, int startIndex = 0) + where T : struct, ISignedNumber, IComparisonOperators { if (dimensions.Length == 0) { - return 1; + return T.One; } - long product = 1; + T product = T.One; for (int i = startIndex; i < dimensions.Length; i++) { - if (dimensions[i] < 0) + if (dimensions[i] < T.Zero) { throw new ArgumentOutOfRangeException($"{nameof(dimensions)}[{i}]"); } @@ -56,6 +58,16 @@ public static long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) return product; } + /// + /// get the product from the start index on the dimensions. + /// + public static int GetProduct(ReadOnlySpan dimensions, int startIndex = 0) => GetProductGeneric(dimensions, startIndex); + + /// + /// get the product from the start index on the dimensions. + /// + public static long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) => GetProductGeneric(dimensions, startIndex); + /// /// Get the Expr Product. /// @@ -76,7 +88,7 @@ public static Expr GetProduct(ReadOnlySpan dimensions, int startIndex = 0) return product; } - public static bool IsAscending(ReadOnlySpan values) + public static bool IsAscending(ReadOnlySpan values) { for (int i = 1; i < values.Length; i++) { @@ -89,7 +101,7 @@ public static bool IsAscending(ReadOnlySpan values) return true; } - public static bool IsDescending(ReadOnlySpan values) + public static bool IsDescending(ReadOnlySpan values) { for (int i = 1; i < values.Length; i++) { @@ -105,16 +117,17 @@ public static bool IsDescending(ReadOnlySpan values) /// /// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout. /// - public static int[] GetStrides(ReadOnlySpan dimensions, bool reverseStride = false) + public static T[] GetStridesGeneric(ReadOnlySpan dimensions, bool reverseStride = false) + where T : struct, ISignedNumber { if (dimensions.IsEmpty) { - return Array.Empty(); + return Array.Empty(); } - int[] strides = new int[dimensions.Length]; + var strides = new T[dimensions.Length]; - int stride = 1; + T stride = T.One; if (reverseStride) { for (int i = 0; i < strides.Length; i++) @@ -135,6 +148,16 @@ public static int[] GetStrides(ReadOnlySpan dimensions, bool reverseStride return strides; } + /// + /// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout. + /// + public static int[] GetStrides(ReadOnlySpan dimensions, bool reverseStride = false) => GetStridesGeneric(dimensions, reverseStride); + + /// + /// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout. + /// + public static long[] GetStrides(ReadOnlySpan dimensions, bool reverseStride = false) => GetStridesGeneric(dimensions, reverseStride); + /// /// get strides. /// @@ -195,22 +218,23 @@ public static void SplitStrides(int[] strides, int[] splitAxes, int[] newStrides /// /// Calculates the 1-d index for n-d indices in layout specified by strides. 
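These helpers now share one generic-math implementation (constrained on ISignedNumber / IBinaryNumber), with thin int and long overloads forwarding to it. A small sketch of how the overloads are expected to resolve, assuming the signatures in this patch:

    // Both overloads forward to the same generic implementation.
    int  p32 = TensorUtilities.GetProduct(new int[]  { 2, 3, 4 });       // 24
    long p64 = TensorUtilities.GetProduct(new long[] { 2, 3, 4 });       // 24L
    long[] strides = TensorUtilities.GetStrides(new long[] { 2, 3, 4 }); // { 12, 4, 1 }
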
/// - public static int GetIndex(ReadOnlySpan strides, ReadOnlySpan indices, int startFromDimension = 0) + public static T GetIndexGeneric(ReadOnlySpan strides, ReadOnlySpan indices, int startFromDimension = 0) + where T : struct, IBinaryNumber, IComparisonOperators { // Scalar if (strides.Length == 0) { - if (indices.Length != 1 || indices[0] != 0) + if (indices.Length != 1 || indices[0] != T.Zero) { throw new ArgumentOutOfRangeException(nameof(indices)); } - return 0; + return T.Zero; } Trace.Assert(strides.Length == indices.Length); - int index = 0; + T index = T.Zero; for (int i = startFromDimension; i < indices.Length; i++) { index += strides[i] * indices[i]; @@ -219,6 +243,16 @@ public static int GetIndex(ReadOnlySpan strides, ReadOnlySpan indices, return index; } + /// + /// Calculates the 1-d index for n-d indices in layout specified by strides. + /// + public static int GetIndex(ReadOnlySpan strides, ReadOnlySpan indices, int startFromDimension = 0) => GetIndexGeneric(strides, indices, startFromDimension); + + /// + /// Calculates the 1-d index for n-d indices in layout specified by strides. + /// + public static long GetIndex(ReadOnlySpan strides, ReadOnlySpan indices, int startFromDimension = 0) => GetIndexGeneric(strides, indices, startFromDimension); + /// /// get index. /// @@ -249,12 +283,12 @@ public static IR.Expr GetIndex(ReadOnlySpan strides, ReadOnlySpan /// Calculates the n-d indices from the 1-d index in a layout specificed by strides. /// - public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, int[] indices, int startFromDimension = 0) + public static void GetIndices(ReadOnlySpan strides, bool reverseStride, long index, long[] indices, int startFromDimension = 0) { Trace.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); Trace.Assert(strides.Length == indices.Length); - int remainder = index; + long remainder = index; for (int i = startFromDimension; i < strides.Length; i++) { // reverse the index for reverseStride so that we divide by largest stride first @@ -269,12 +303,12 @@ public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int /// /// Calculates the n-d indices from the 1-d index in a layout specificed by strides. /// - public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, Span indices, int startFromDimension = 0) + public static void GetIndices(ReadOnlySpan strides, bool reverseStride, long index, Span indices, int startFromDimension = 0) { Trace.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); Trace.Assert(strides.Length == indices.Length); - int remainder = index; + long remainder = index; for (int i = startFromDimension; i < strides.Length; i++) { // reverse the index for reverseStride so that we divide by largest stride first @@ -289,14 +323,14 @@ public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int /// /// Takes an 1-d index over n-d sourceStrides and recalculates it assuming same n-d coordinates over a different n-d strides. /// - public static int TransformIndexByStrides(int index, int[] sourceStrides, bool sourceReverseStride, int[] transformStrides) + public static long TransformIndexByStrides(long index, long[] sourceStrides, bool sourceReverseStride, long[] transformStrides) { Trace.Assert(index >= 0); Trace.Assert(sourceReverseStride ? 
IsAscending(sourceStrides) : IsDescending(sourceStrides), "Index decomposition requires ordered strides"); Trace.Assert(sourceStrides.Length == transformStrides.Length); - int transformIndex = 0; - int remainder = index; + long transformIndex = 0; + long remainder = index; for (int i = 0; i < sourceStrides.Length; i++) { @@ -316,7 +350,7 @@ public static int TransformIndexByStrides(int index, int[] sourceStrides, bool s /// /// check this dimension and strides is contiguous. /// - public static bool IsContiguous(ReadOnlySpan dimensions, ReadOnlySpan strides) + public static bool IsContiguous(ReadOnlySpan dimensions, ReadOnlySpan strides) { return System.Collections.StructuralComparisons.StructuralEqualityComparer.Equals(GetStrides(dimensions), strides.ToArray()); } @@ -324,7 +358,7 @@ public static bool IsContiguous(ReadOnlySpan dimensions, ReadOnlySpan /// /// check the dimensions selected range is contiguous. /// - public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices, out int contiguousStart) + public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices, out int contiguousStart) { if (dimensions.Length != slices.Length) { @@ -341,7 +375,7 @@ public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan< status = (end - start) switch { // is full - int x when x == dimensions[i] => status switch + long x when x == dimensions[i] => status switch { SliceStatus.IsSlice => x == 1 ? SliceStatus.IsSlice : @@ -353,7 +387,7 @@ public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan< }, // when has - int x when x > 0 && x < dimensions[i] => status switch + long x when x > 0 && x < dimensions[i] => status switch { SliceStatus.IsSlice => x == 1 ? SliceStatus.IsSlice : @@ -377,7 +411,7 @@ public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan< return true; } - public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices) => IsContiguousSlice(dimensions, slices, out _); + public static bool IsContiguousSlice(ReadOnlySpan dimensions, ReadOnlySpan slices) => IsContiguousSlice(dimensions, slices, out _); public static long[] ToLongs(this ReadOnlySpan ints) { @@ -397,7 +431,7 @@ public static int[] ToInts(this ReadOnlySpan longs) var ints = new int[longs.Length]; for (int i = 0; i < ints.Length; i++) { - ints[i] = (int)longs[i]; + ints[i] = checked((int)longs[i]); } return ints; @@ -405,7 +439,7 @@ public static int[] ToInts(this ReadOnlySpan longs) public static int[] ToInts(this long[] longs) => ToInts((ReadOnlySpan)longs); - public static long GetSize(Span shapes, Span strides, int elementSize) + public static long GetSize(Span shapes, Span strides, int elementSize) { long size = 0; for (int i = 0; i < shapes.Length; i++) @@ -417,10 +451,10 @@ public static long GetSize(Span shapes, Span strides, int elementSize) return size * elementSize; } - public static (long Size, int[] Strides) GetTensorSizeAndStrides(TensorType tensorType, DistributedType? distributedType) + public static (long Size, long[] Strides) GetTensorSizeAndStrides(TensorType tensorType, DistributedType? 
distributedType) { - int[] dims; - int[] strides; + long[] dims; + long[] strides; if (distributedType is null) { dims = tensorType.Shape.ToValueArray(); @@ -436,7 +470,7 @@ public static (long Size, int[] Strides) GetTensorSizeAndStrides(TensorType tens return (GetProduct(dims) * tensorType.DType.SizeInBytes, strides); } - public static (long Size, int[] Strides) GetTensorSizeAndStrides(IRType type) + public static (long Size, long[] Strides) GetTensorSizeAndStrides(IRType type) => type switch { TensorType tensorType => GetTensorSizeAndStrides(tensorType, null), diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index b25313791..c2033018b 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -15,11 +15,11 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens for (int i = 0; i < placement.Rank; i++) { var ndsbp = new List(); - if (!tensorType.Shape.ToValueArray().Contains(0)) + if (tensorType.Shape.All(x => x.IsUnknown || x.FixedValue != 0)) { for (int axis = 0; axis < tensorType.Shape.Rank; axis++) { - if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && placement.Hierarchy[i] > 1 && IsDivideExactly(s, placement.Hierarchy[i])) + if (tensorType.Shape[axis] is { IsFixed: true, FixedValue: long s } && placement.Hierarchy[i] > 1 && IsDivideExactly(s, placement.Hierarchy[i])) { ndsbp.Add(SBP.S(axis)); } @@ -72,7 +72,7 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement) { - if (!tensorType.Shape.IsFixed) + if (!tensorType.Shape.IsRanked) { return false; } @@ -83,8 +83,8 @@ public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsb public static IReadOnlyList GetDivisors(DistributedType distributedType) { - var shape = distributedType.TensorType.Shape.ToValueArray(); - var divisors = Enumerable.Repeat(0, shape.Length).ToArray(); + var rank = distributedType.TensorType.Shape.Rank; + var divisors = Enumerable.Repeat(0, rank).ToArray(); for (int i = 0; i < distributedType.NdSBP.Count; i++) { if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis }) @@ -154,7 +154,7 @@ public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedTyp }).ToArray(); } - public static List TryGetNonUniformDividedSlice(DistributedType distributedType) + public static List TryGetNonUniformDividedSlice(DistributedType distributedType) { var shape = distributedType.TensorType.Shape.ToValueArray(); var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List()).ToArray(); @@ -166,9 +166,9 @@ public static List TryGetNonUniformDividedSlice(DistributedType distribut } } - var spliList = hierarchies.Select, int[]>((divs, axis) => + var spliList = hierarchies.Select, long[]>((divs, axis) => { - int[] dim; + long[] dim; if (divs.Any()) { var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); @@ -188,8 +188,8 @@ public static List TryGetNonUniformDividedSlice(DistributedType distribut return dim; }).ToList(); - IEnumerable ret = new[] { Array.Empty() }; - foreach (int[] array in spliList) + IEnumerable ret = new[] { Array.Empty() }; + foreach (long[] array in spliList) { ret = from seq in ret from item in array @@ -199,7 +199,7 @@ from item in array return ret.ToList(); } - public static bool IsDivideBy(int input, int divisor) + public static bool 
IsDivideBy(long input, int divisor) { if (input >= divisor) { @@ -209,7 +209,7 @@ public static bool IsDivideBy(int input, int divisor) return false; } - public static bool IsDivideExactly(int input, int divisor) + public static bool IsDivideExactly(long input, int divisor) { if (input >= divisor && input % divisor == 0) { @@ -227,11 +227,11 @@ public static float GetDividedTensorEfficiency(DistributedType distributedType, return 1f; } - return Enumerable.Range(0, tiles.Count).Select(i => tiles[i].Ranges(0, shape[i])).CartesianProduct().Select(rgs => + return Enumerable.Range(0, tiles.Count).Select(i => ((int)tiles[i].FixedValue).Ranges(0, (int)shape[i].FixedValue)).CartesianProduct().Select(rgs => { var slice = rgs.ToArray(); - var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); - var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; + var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToValueArray(), slice, out var contiguousStart); + var size = TensorUtilities.GetProduct(tiles.ToValueArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; var (div, rem) = Math.DivRem(size, burstLength); return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); }).Average(); @@ -240,7 +240,7 @@ public static float GetDividedTensorEfficiency(DistributedType distributedType, public static TensorType GetDividedTensorType(DistributedType distributedType) { var (tiles, _) = GetDividedTile(distributedType); - return distributedType.TensorType with { Shape = new Shape(tiles) }; + return distributedType.TensorType with { Shape = tiles }; } public static int[] GetUnraveledIndex(int index, int[] hierarchies) @@ -257,11 +257,11 @@ public static int[] GetUnraveledIndex(int index, int[] hierarchies) return unraveledIndex; } - public static (int[] Offset, int[] Shape) GetLocalOffsetAndShape(DistributedType distributedType, int[] shardIndex) + public static (long[] Offset, long[] Shape) GetLocalOffsetAndShape(DistributedType distributedType, int[] shardIndex) { var globalShape = distributedType.TensorType.Shape.ToValueArray(); - var offset = new int[distributedType.TensorType.Shape.Rank]; - var shape = new int[distributedType.TensorType.Shape.Rank]; + var offset = new long[distributedType.TensorType.Shape.Rank]; + var shape = new long[distributedType.TensorType.Shape.Rank]; for (int axis = 0; axis < offset.Length; axis++) { var splits = (from d in distributedType.NdSBP.Select((s, i) => (s, i)) @@ -289,10 +289,10 @@ public static (int[] Offset, int[] Shape) GetLocalOffsetAndShape(DistributedType return (offset, shape); } - private static (IReadOnlyList Tile, IReadOnlyList Shape) GetDividedTile(DistributedType distributedType) + private static (Shape Tile, Shape Shape) GetDividedTile(DistributedType distributedType) { - var shape = distributedType.TensorType.Shape.ToValueArray(); - var tiles = distributedType.TensorType.Shape.ToValueArray(); + var shape = distributedType.TensorType.Shape.ToArray(); + var tiles = distributedType.TensorType.Shape.ToArray(); foreach (var (s, i) in distributedType.NdSBP.Select((s, i) => (s, i)).Where(t => t.s is SBPSplit).Select(t => ((SBPSplit)t.s, t.i))) { tiles[s.Axis] /= distributedType.Placement.Hierarchy[i]; diff --git a/src/Nncase.Core/Utilities/DumpUtility.cs b/src/Nncase.Core/Utilities/DumpUtility.cs index 24c863fb7..d1f9fb932 100644 --- a/src/Nncase.Core/Utilities/DumpUtility.cs +++ b/src/Nncase.Core/Utilities/DumpUtility.cs @@ -223,7 
+223,7 @@ public static Tensor ReadBinFile(string path, DataType dt, Shape shape) using (var stream = new FileStream(Path.Join(path), FileMode.Open, FileAccess.Read, FileShare.None)) using (var reader = new BinaryReader(stream)) { - var bytes = reader.ReadBytes(shape.Prod().FixedValue * dt.SizeInBytes); + var bytes = reader.ReadBytes(checked((int)shape.Prod().FixedValue * dt.SizeInBytes)); return Tensor.FromBytes(dt, bytes, shape); } } diff --git a/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs index ca0880da7..73bbb0666 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILDotPrintVisitor.cs @@ -22,6 +22,8 @@ using GiGraph.Dot.Types.Styling; using NetFabric.Hyperlinq; using Nncase.IR; +using Nncase.IR.Affine; +using Nncase.IR.Buffers; namespace Nncase.Diagnostics; @@ -160,6 +162,95 @@ protected override ILDotOption VisitFunction(Function expr) return new(expr.Name); } + protected override ILDotOption VisitBufferOf(BufferOf expr) + { + if (!_exprMemo.TryGetValue(expr, out var result)) + { + var id = _idCounter++; + string exprId = "\"" + id.ToString() + "\""; + + var table = new DotHtmlTable + { + BorderWidth = 0, + CellBorderWidth = 1, + CellSpacing = 0, + }; + + var connect_list = new List<(Expr, string)>(); + + // 1. the connect type. + table.AddRow(row => + { + row.AddCell("BufferOf"); // key wrods type. + row.AddCell(Visit(expr.Input).Str); // target. + }); + + // 3. make crrent node. + var dotNode = _dotGraph.Nodes.Add(exprId); + dotNode.ToPlainHtmlNode(table); + + // 4. connect edge. + // _dotGraph.Edges.Add(Visit(expr.Input).DotNode, dotNode); + result = new(dotNode); + _exprMemo.Add(expr, result); + } + + return result; + } + + protected override ILDotOption VisitGrid(Grid expr) + { + if (!_exprMemo.TryGetValue(expr, out var result)) + { + var id = _idCounter++; + string exprId = "\"" + id.ToString() + "\""; + + var table = new DotHtmlTable + { + BorderWidth = 0, + CellBorderWidth = 1, + CellSpacing = 0, + }; + + var connect_list = new List<(Expr, string)>(); + + // 1. the connect type. + table.AddRow(row => + { + row.AddCell("Grid"); // key wrods type. + int count = 0; + foreach (var child in expr.Buffers) + { + var childnode = Visit(child); + var portName = $"P{count++}"; + row.AddCell(childnode.IsDotNode ? string.Empty : childnode.Str, cell => cell.PortName = portName); + if (childnode.IsDotNode) + { + connect_list.Add((child, portName)); + } + } + }); + + // 3. make crrent node. + var dotNode = _dotGraph.Nodes.Add(exprId); + dotNode.ToPlainHtmlNode(table); + + // 4. connect edge. + foreach (var (child, port_name) in connect_list) + { + _dotGraph.Edges.Add(Visit(child).DotNode, dotNode, edge => + { + edge.Head.Endpoint.Port = new DotEndpointPort(port_name); + }); + } + + result = new(dotNode); + _exprMemo.Add(expr, result); + } + + return result; + } + protected override ILDotOption VisitTuple(IR.Tuple expr) { if (!_exprMemo.TryGetValue(expr, out var result)) diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index 00e740e1c..7084b3fbe 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -267,10 +267,10 @@ public override string VisitType(CallableType type) => /// public override string VisitType(TensorType type) => type.DType switch { - PrimType ptype => ptype.GetDisplayName() + (type.Shape.IsScalar ? 
string.Empty : type.Shape.ToString()), + PrimType ptype => ptype.GetDisplayName() + (type.Shape.IsScalar ? string.Empty : VisitShape(type.Shape)), PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}", ValueType => $"{type.DType}", - VectorType vtype => $"{vtype.ElemType.GetDisplayName()}<{string.Join(",", vtype.Lanes)}>" + (type.Shape.IsScalar ? string.Empty : type.Shape.ToString()), + VectorType vtype => $"{vtype.ElemType.GetDisplayName()}<{string.Join(",", vtype.Lanes)}>" + (type.Shape.IsScalar ? string.Empty : VisitShape(type.Shape)), _ => throw new NotSupportedException(type.DType.GetType().Name), }; @@ -706,6 +706,23 @@ private string AllocateTempVar(Expr expr) return name; } + private string VisitShape(Shape shape) => + shape.Kind switch + { + ShapeKind.Invalid => "Invalid", + ShapeKind.Unranked => "Unranked", + _ => $"[{string.Join(',', shape.Select(VisitDimension))}]", + }; + + private string VisitDimension(Dimension dimension) => + dimension.Kind switch + { + DimensionKind.Any => "any", + DimensionKind.Fixed => dimension.FixedValue.ToString(), + DimensionKind.Unknown => dimension.Value is Var var ? $"%{var.Name}" : "?", + _ => throw new NotSupportedException(dimension.Kind.ToString()), + }; + private void AppendCheckedType(IRType? type, string end = "", bool hasNewLine = true) { if (type is not null) diff --git a/src/Nncase.EGraph/CostModel/EGraphCostModel.cs b/src/Nncase.EGraph/CostModel/EGraphCostModel.cs index 221deeb35..511873b69 100644 --- a/src/Nncase.EGraph/CostModel/EGraphCostModel.cs +++ b/src/Nncase.EGraph/CostModel/EGraphCostModel.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -20,4 +21,9 @@ public EGraphCostModel(IReadOnlyDictionary costs) } public Cost this[ENode enode] => _costs[enode]; + + public bool TryGet(ENode node, [MaybeNullWhen(false)] out Cost cost) + { + return _costs.TryGetValue(node, out cost); + } } diff --git a/src/Nncase.EGraph/CostModel/EGraphSatPrinter.cs b/src/Nncase.EGraph/CostModel/EGraphSatPrinter.cs index 32cc8ae17..71c8f72e3 100644 --- a/src/Nncase.EGraph/CostModel/EGraphSatPrinter.cs +++ b/src/Nncase.EGraph/CostModel/EGraphSatPrinter.cs @@ -37,18 +37,20 @@ private DotGraph AttachEGraphCostPick(CostModel.EGraphCostModel costModel, IRead // 1. display each enode costs. foreach (var (enode, (dotnode, table)) in NodesMap) { - var cost = costModel[enode]; - if (cost != CostModel.Cost.Zero) + if (costModel.TryGet(enode, out var cost)) { - table.AddRow(row => + if (cost != CostModel.Cost.Zero) { - foreach (var (k, v) in cost.Factors) + table.AddRow(row => { - row.AddCell($"{k}: {v:F2}"); - } + foreach (var (k, v) in cost.Factors) + { + row.AddCell($"{k}: {v:F2}"); + } - row.AddCell($"Score: {cost.Score:F2}"); - }); + row.AddCell($"Score: {cost.Score:F2}"); + }); + } } dotnode.ToPlainHtmlNode(table); diff --git a/src/Nncase.EGraph/Passes/EGraphExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractor.cs index 9bfeb504f..cbe424919 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractor.cs @@ -27,6 +27,7 @@ public EGraphExtractor(EGraphCostModel costModel) public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains) { var cpmodel = new CpModel(); + var nodes = CollectNodes(root); // 0. create bool var for all enode. 
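+ // Descriptive note on the extraction change: the boolean decision variables, the
+ // per-node "pick one child" constraints and the weighted-sum objective below are all
+ // built only over the e-nodes reachable from the root e-class (gathered by the new
+ // CollectNodes helper defined later in this file), rather than over every node in the
+ // e-graph, so the size of the CP-SAT model is bounded by the reachable sub-graph.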
var varMemo = new Dictionary(); @@ -34,7 +35,10 @@ public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] const { foreach (var (e, i) in cls.Nodes.Select((e, i) => (e, i))) { - varMemo.Add(e, cpmodel.NewBoolVar($"{cls.Id}_{i}")); + if (nodes.Contains(e)) + { + varMemo.Add(e, cpmodel.NewBoolVar($"{cls.Id}_{i}")); + } } } @@ -42,7 +46,7 @@ public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] const cpmodel.AddBoolOr(root.Nodes.Select(n => varMemo[n]).ToArray()); // 2. when pick node, must pick one child node. - foreach (var n in eGraph.Nodes) + foreach (var n in nodes) { var ns = new[] { varMemo[n].Not() }; foreach (var child in n.Children) @@ -103,7 +107,7 @@ public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] const } // 3. add pick weights for all enode. - cpmodel.Minimize(LinearExpr.WeightedSum(eGraph.Nodes.Select(n => varMemo[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score)))); + cpmodel.Minimize(LinearExpr.WeightedSum(nodes.Select(n => varMemo[n]), nodes.Select(n => checked((long)_costModel[n].Score)))); if (cpmodel.Validate().Any()) { @@ -155,7 +159,7 @@ public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] const throw new InvalidProgramException("SatExtract Failed!"); } - var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(varMemo[e])); + var picks = nodes.ToDictionary(e => e, e => solver.BooleanValue(varMemo[e])); using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null) { EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream); @@ -506,6 +510,31 @@ void StrongConnect(int v) return (components, sccAdjList); } + + private HashSet CollectNodes(EClass root) + { + var visited = new HashSet(); + void Visit(ENode node) + { + if (visited.Add(node)) + { + foreach (var child in node.Children) + { + foreach (var n in child.Nodes) + { + Visit(n); + } + } + } + } + + foreach (var n in root.Nodes) + { + Visit(n); + } + + return visited; + } } internal sealed class PrintCostCallBack : CpSolverSolutionCallback diff --git a/src/Nncase.Evaluator/Buffers/BufferSubview.cs b/src/Nncase.Evaluator/Buffers/BufferSubview.cs index 1864b0839..69711208b 100644 --- a/src/Nncase.Evaluator/Buffers/BufferSubview.cs +++ b/src/Nncase.Evaluator/Buffers/BufferSubview.cs @@ -21,8 +21,8 @@ public IRType Visit(ITypeInferenceContext context, BufferSubview target) var shapeExpr = context.GetArgument(target, BufferSubview.Shape); var shape = shapeExpr switch { - IR.Tuple t => new Shape(t.Fields.AsValueEnumerable().Select(d => d is TensorConst tc ? new Dimension(tc.Value.ToScalar()) : Dimension.Unknown).ToArray()), - TupleConst tc => new Shape(tc.Value.Select(d => d is Tensor t ? 
new Dimension(t.ToScalar()) : Dimension.Unknown)), + IR.Tuple t => new Shape(t.Fields), + TupleConst tc => new Shape(tc.Value.AsTensor().ToArray()), _ => throw new ArgumentException("Invalid shape argument."), }; return new TensorType(buffer.CheckedDataType, shape); diff --git a/src/Nncase.Evaluator/Buffers/Uninitialized.cs b/src/Nncase.Evaluator/Buffers/Uninitialized.cs index 1a6d36577..5dbd8b175 100644 --- a/src/Nncase.Evaluator/Buffers/Uninitialized.cs +++ b/src/Nncase.Evaluator/Buffers/Uninitialized.cs @@ -24,8 +24,8 @@ public IRType Visit(ITypeInferenceContext context, Uninitialized target) } else { - var shape = context.CheckArgumentType(target, Uninitialized.Shape); - tensorType = new TensorType(target.DType, new(Enumerable.Repeat(Dimension.Unknown, shape.Shape[0].FixedValue))); + var shape = context.GetArgument(target, Uninitialized.Shape); + tensorType = new TensorType(target.DType, Shape.FromExpr(shape)); } return target.Placement.Rank == 0 ? tensorType : new DistributedType(tensorType, target.NdSBP, target.Placement); diff --git a/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs b/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs index 35291304d..7920c436a 100644 --- a/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs +++ b/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs @@ -53,12 +53,12 @@ public static class OrtKIExtensions public static Tensor ToTensor(this OrtKISharp.Tensor tensor) { - return Tensor.From(tensor.DataType.ToDataType(), new TensorInitializerWithOrt(tensor), tensor.Shape.ToInts()); + return Tensor.From(tensor.DataType.ToDataType(), new TensorInitializerWithOrt(tensor), tensor.Shape); } public static Tensor ToTensor(this OrtKISharp.Tensor tensor, TensorType tensorType) { - return Tensor.From(tensorType.DType, new TensorInitializerWithOrt(tensor), tensorType.Shape.IsFixed ? tensorType.Shape : tensor.Shape.ToInts()); + return Tensor.From(tensorType.DType, new TensorInitializerWithOrt(tensor), tensorType.Shape.IsFixed ? 
tensorType.Shape : tensor.Shape); } public static TensorValue ToValue(this OrtKISharp.Tensor tensor) @@ -68,8 +68,8 @@ public static TensorValue ToValue(this OrtKISharp.Tensor tensor) public static OrtKISharp.Tensor ToOrtTensor(this Tensor tensor) => tensor.ElementType switch { - VectorType vectorType => ToOrtTensor(tensor, vectorType.ElemType.ToOrtType(), tensor.Dimensions.ToArray().Concat(vectorType.Lanes.ToArray()).ToArray()), - PrimType primType => ToOrtTensor(tensor, primType.ToOrtType(), tensor.Dimensions.ToArray()), + VectorType vectorType => ToOrtTensor(tensor, vectorType.ElemType.ToOrtType(), tensor.Dimensions.ToInts().Concat(vectorType.Lanes).ToArray()), + PrimType primType => ToOrtTensor(tensor, primType.ToOrtType(), tensor.Dimensions.ToInts()), _ => throw new NotSupportedException(), }; diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 601e5205a..817714bde 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -150,7 +150,7 @@ public static IRType VisitTensorType(TensorType lhs, TensorType rhs, bool packin return new InvalidType("MatMul lhs and rhs have different DType"); } - if (lhs.Shape[lk] != rhs.Shape[rk] && lhs.Shape[lk] != Dimension.Unknown && rhs.Shape[rk] != Dimension.Unknown) + if (lhs.Shape[lk] != rhs.Shape[rk] && lhs.Shape[lk].IsFixed && rhs.Shape[rk].IsFixed) { return new InvalidType("MatMul lhs and rhs have not compatiable shape"); } @@ -214,8 +214,8 @@ public static IRType VisitTensorType(TensorType lhs, TensorType rhs, bool packin var rhsShape = lhs.Shape.Rank <= rhs.Shape.Rank ? rhs.Shape.ToArray() : Enumerable.Repeat((Dimension)1, lhs.Shape.Rank - rhs.Shape.Rank).Concat(rhs.Shape).ToArray(); var bigShape = Enumerable.Zip(lhsShape, rhsShape).SkipLast(2).Select(t => - t.First == Dimension.Unknown || t.Second == Dimension.Unknown - ? Dimension.Unknown + t.First.IsUnknown || t.Second.IsUnknown + ? 
(Dimension)IR.F.Math.Max(t.First.Value, t.Second.Value) : System.Math.Max(t.First.FixedValue, t.Second.FixedValue)).ToArray(); // batch and channel diff --git a/src/Nncase.Evaluator/NN/BatchToSpace.cs b/src/Nncase.Evaluator/NN/BatchToSpace.cs index 8da179f45..6a5240926 100644 --- a/src/Nncase.Evaluator/NN/BatchToSpace.cs +++ b/src/Nncase.Evaluator/NN/BatchToSpace.cs @@ -39,7 +39,7 @@ public IValue Visit(IEvaluateContext context, BatchToSpace s) var targetSpatial = ZipExec(spatial, blockShape, (x, y) => x * y); var ccat1 = spatial.Concat(blockShape).ToArray(); - var re1 = Tensor.From(ccat1, new[] { ccat1.Length / blockLen, blockLen }); + var re1 = Tensor.From(ccat1, [ccat1.Length / blockLen, blockLen]); var interLeave = OrtKI.Transpose(re1.ToOrtTensor(), new long[] { 1, 0 }).ToArray(); var shape1 = new int[] { -1 }.Concat(interLeave).Concat(depth).ToArray(); @@ -125,7 +125,7 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target) var blockSize = Prod(blockShape); var batch = inShape[0]; var d0 = batch / blockSize; - var m = blockShape.CheckedShape[0].FixedValue; + var m = (int)blockShape.CheckedShape[0].FixedValue; var cropSection = Enumerable.Range(0, m).Select( i => (inShape[i + 1] * blockShape[0]) - crops[i, 0] - crops[i, 1]).ToArray(); @@ -200,7 +200,7 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT var blockSize = blockShapeArr.Aggregate(1, (a, b) => a * b); var d0 = batch / blockSize; Trace.Assert(blockShape.Shape[0] == crops.Shape[0]); - var m = blockShape.Shape[0].FixedValue; + var m = (int)blockShape.Shape[0].FixedValue; var cropsV = cropsValue.Value.Cast(); var cropSection = Enumerable.Range(0, m).Select( i => (inShape[i + 1] * blockShapeArr[i]) - cropsV[i, 0] - cropsV[i, 1]); @@ -216,7 +216,8 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT } else { - return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray()); + // return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray()); + throw new NotImplementedException(); } } } diff --git a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs index 56ae68127..b17a78586 100644 --- a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs +++ b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs @@ -111,7 +111,7 @@ public IValue Visit(IEvaluateContext context, Conv2DTranspose conv) outCache[i] = outCache[i] + biasArray[biasIdx]; } - return new TensorValue(Tensor.From(outCache, new[] { (int)outputShape[0], (int)outputShape[1], (int)outputShape[2], (int)outputShape[3] })); + return new TensorValue(Tensor.From(outCache, [outputShape[0], outputShape[1], outputShape[2], outputShape[3]])); } /// diff --git a/src/Nncase.Evaluator/NN/LayerNorm.cs b/src/Nncase.Evaluator/NN/LayerNorm.cs index a0442a039..8ae72fd24 100644 --- a/src/Nncase.Evaluator/NN/LayerNorm.cs +++ b/src/Nncase.Evaluator/NN/LayerNorm.cs @@ -99,7 +99,7 @@ public IValue Visit(IEvaluateContext context, LayerNorm layerNorm) // return Value.FromTensor(OrtKI.LayerNormalization(input, scale, bias, layerNorm.Axis, layerNorm.Epsilon, 1)); var shape = input.Shape.ToValueArray(); - var output = LayerNormImpl(shape, input.Buffer.Span, scale.Buffer.Span, bias.Buffer.Span, layerNorm.Axis, layerNorm.Epsilon, layerNorm.UseMean); + var output = LayerNormImpl(shape.ToInts(), input.Buffer.Span, scale.Buffer.Span, bias.Buffer.Span, layerNorm.Axis, layerNorm.Epsilon, layerNorm.UseMean); return 
Value.FromTensor(Tensor.From(output, shape)); } diff --git a/src/Nncase.Evaluator/NN/SpaceToBatch.cs b/src/Nncase.Evaluator/NN/SpaceToBatch.cs index 89c574fe4..2b9ed1527 100644 --- a/src/Nncase.Evaluator/NN/SpaceToBatch.cs +++ b/src/Nncase.Evaluator/NN/SpaceToBatch.cs @@ -247,7 +247,7 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT var outshape = new List { padded_shape[0] }; foreach (var i in Enumerable.Range(1, m)) { - outshape.Add(padded_shape[i].IsUnknown ? Dimension.Unknown : + outshape.Add(padded_shape[i].IsUnknown ? padded_shape[i] / ts_block_shape[i - 1] : padded_shape[i].FixedValue % ts_block_shape[i - 1] == 0 ? padded_shape[i].FixedValue / ts_block_shape[i - 1] : throw new TypeInferenceInterruptException( @@ -261,7 +261,7 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT foreach (var block in ts_block_shape) { - outshape[0] = outshape[0].IsUnknown ? Dimension.Unknown : outshape[0].FixedValue * block; + outshape[0] = outshape[0].FixedValue * block; } var outputShape = ShapeNHWCToNCHW(inShape, outshape); @@ -269,6 +269,7 @@ private IRType Visit(ITypeInferenceContext context, SpaceToBatch target, TensorT return input with { Shape = new Shape(outputShape) }; } - return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray()); + // return new TensorType(input.DType, Enumerable.Repeat(Dimension.Unknown, input.Shape.Count).ToArray()); + throw new NotImplementedException(); } } diff --git a/src/Nncase.Evaluator/RNN/LSTM.cs b/src/Nncase.Evaluator/RNN/LSTM.cs index fc1d1f3da..83ddda862 100644 --- a/src/Nncase.Evaluator/RNN/LSTM.cs +++ b/src/Nncase.Evaluator/RNN/LSTM.cs @@ -134,12 +134,7 @@ private TensorType InferYType(ITypeInferenceContext context, LSTM target, Tensor // [batch_size, seq_length, num_directions, hidden_size] var yShape = x.Shape.ToList(); yShape.Insert(seqLenIndex + 1, numDirections); - var hiddenSize = Dimension.Unknown; - if (context.GetArgument(target, LSTM.HiddenSize) is TensorConst hiddenSizeConst) - { - hiddenSize = hiddenSizeConst.Value.ToScalar(); - } - + var hiddenSize = context.GetArgument(target, LSTM.HiddenSize); yShape[^1] = hiddenSize; return x with { Shape = yShape.ToArray() }; } diff --git a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs index c7cb702bc..f1f084119 100644 --- a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs +++ b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs @@ -30,7 +30,8 @@ public IValue Visit(IEvaluateContext context, BroadcastShape broadcastShape) public IRType Visit(ITypeInferenceContext context, BroadcastShape target) { - return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + // return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + throw new NotImplementedException(); } public Cost Visit(ICostEvaluateContext context, BroadcastShape target) diff --git a/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs index df1c668bd..a84c2b50a 100644 --- a/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs +++ b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs @@ -28,7 +28,8 @@ public IRType Visit(ITypeInferenceContext context, SqueezeShape target) var dims = context.CheckArgumentType(target, SqueezeShape.Dim); if (!input.CheckedShape.IsFixed) { - return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + // return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + throw new 
NotImplementedException(); } return new TensorType(DataTypes.Int64, new[] { input.CheckedShape.Size - dims.Shape[0] }); diff --git a/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs index c195b6c2d..edb188d2f 100644 --- a/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs +++ b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs @@ -28,7 +28,8 @@ public IRType Visit(ITypeInferenceContext context, UnsqueezeShape target) var dims = context.CheckArgumentType(target, UnsqueezeShape.Dim); if (!input.Shape.IsFixed) { - return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + // return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + throw new NotImplementedException(); } return new TensorType(DataTypes.Int64, new[] { input.Shape.Size + dims.Shape[0] }); diff --git a/src/Nncase.Evaluator/Tensors/Broadcast.cs b/src/Nncase.Evaluator/Tensors/Broadcast.cs index 53b1c697f..d160503e7 100644 --- a/src/Nncase.Evaluator/Tensors/Broadcast.cs +++ b/src/Nncase.Evaluator/Tensors/Broadcast.cs @@ -50,16 +50,6 @@ public Metric Visit(IMetricEvaluateContext context, Broadcast target) private IRType Visit(TensorType input, TensorType shape, ITypeInferenceContext context, Broadcast op) { var shapeValue = context.GetArgument(op, Broadcast.Shape); - if (shapeValue is TensorConst constShapeValue && input.Shape.IsFixed) - { - return TypeInference.BroadcastType(input, new TensorType(input.DType, constShapeValue.Value.ToArray())); - } - - if (shape.Shape[0].IsFixed) - { - return input with { Shape = Enumerable.Repeat(Dimension.Unknown, shape.Shape[0].FixedValue).ToArray() }; - } - - return input with { Shape = IR.Shape.Unranked }; + return TypeInference.BroadcastType(input, new TensorType(input.DType, Shape.FromExpr(shapeValue))); } } diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs index 942c821bd..0a3039710 100644 --- a/src/Nncase.Evaluator/Tensors/Concat.cs +++ b/src/Nncase.Evaluator/Tensors/Concat.cs @@ -147,12 +147,7 @@ private IRType Visit(TupleType inputs, int axis) } var d = GetTensorType(inType).Shape[i]; - if (d.IsUnknown) - { - return Dimension.Unknown; - } - - if (d.FixedValue != GetTensorType(inputs[0]).Shape[i]) + if (d.IsFixed && d.FixedValue != GetTensorType(inputs[0]).Shape[i]) { allAxisDimIsSame = false; } @@ -165,7 +160,7 @@ private IRType Visit(TupleType inputs, int axis) else { invalidType = new InvalidType("Concat dims that except the shape of axis dim are different"); - return Dimension.Unknown; + return -1; } } }); @@ -227,18 +222,8 @@ private IRType Visit(TupleType inputs, int axis) // else get sum of dims private Dimension AxisDim(TupleType inputs, int axisValue) { - var allAxisDimIsFixed = inputs.Fields.Aggregate( - true, - (prod, next) => prod && (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].IsFixed); - if (allAxisDimIsFixed) - { - return inputs.Fields.Aggregate( - 0, - (prod, next) => prod + (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].FixedValue); - } - else - { - return Dimension.Unknown; - } + return inputs.Fields.Aggregate( + (Dimension)0, + (prod, next) => prod + (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].Value); } } diff --git a/src/Nncase.Evaluator/Tensors/ConstantOfShape.cs b/src/Nncase.Evaluator/Tensors/ConstantOfShape.cs 
index 163f07ad9..e4532ccc4 100644 --- a/src/Nncase.Evaluator/Tensors/ConstantOfShape.cs +++ b/src/Nncase.Evaluator/Tensors/ConstantOfShape.cs @@ -18,9 +18,9 @@ public class ConstantOfShapeEvaluator : IEvaluator, ITypeInfere /// public IValue Visit(IEvaluateContext context, ConstantOfShape target) { - var shape = context.GetArgumentValueAsArray(target, ConstantOfShape.Shape); + var shape = context.GetArgumentValueAsArray(target, ConstantOfShape.Shape); var value = context.GetArgumentValueAsTensor(target, ConstantOfShape.Value); - var result = Enumerable.Repeat(value.ToScalar(), shape.Aggregate(1, (i, i1) => i * i1)).ToArray(); + var result = Enumerable.Repeat(value.ToScalar(), shape.Aggregate(1, (i, i1) => i * (int)i1)).ToArray(); return OrtKI.Cast(Tensor.From(result, shape).ToOrtTensor(), (int)value.ElementType.ToOrtType()).ToValue(); } @@ -28,17 +28,9 @@ public IValue Visit(IEvaluateContext context, ConstantOfShape target) public IRType Visit(ITypeInferenceContext context, ConstantOfShape target) { var value = context.CheckArgumentType(target, ConstantOfShape.Value); - var shape = context.CheckArgumentType(target, ConstantOfShape.Shape); + var shape = context.GetArgument(target, ConstantOfShape.Shape); var type = value.DType; - if (context.GetArgument(target, ConstantOfShape.Shape) is TensorConst shapeValue) - { - return new TensorType(type, shapeValue.Value.ToArray()); - } - else - { - var outShape = TypeInference.ReshapeTo(shape); - return new TensorType(type, outShape); - } + return new TensorType(type, Shape.FromExpr(shape)); } public Cost Visit(ICostEvaluateContext context, ConstantOfShape target) diff --git a/src/Nncase.Evaluator/Tensors/Expand.cs b/src/Nncase.Evaluator/Tensors/Expand.cs index e06241f90..a13e7f83e 100644 --- a/src/Nncase.Evaluator/Tensors/Expand.cs +++ b/src/Nncase.Evaluator/Tensors/Expand.cs @@ -56,7 +56,7 @@ public Metric Visit(IMetricEvaluateContext context, Expand target) public IRType Visit(ITypeInferenceContext context, Expand target) { var input = context.CheckArgumentType(target, Expand.Input); - var shape = context.CheckArgumentType(target, Expand.Shape); + var shape = context.CheckArgumentTensorTypeOrBroadcast(target, Expand.Shape); return input switch { TensorType t => Visit(context, target, t, shape), @@ -67,15 +67,8 @@ public IRType Visit(ITypeInferenceContext context, Expand target) private IRType Visit(ITypeInferenceContext context, Expand target, TensorType input, TensorType shape) { - var shape_expr = context.GetArgument(target, Expand.Shape); - if (shape_expr is TensorConst constShape) - { - return input with { Shape = new Shape(constShape.Value.Cast()) }; - } - else - { - return input with { Shape = TypeInference.ReshapeTo(shape) }; - } + var shapeExpr = context.GetArgument(target, Expand.Shape); + return input with { Shape = Shape.FromExpr(shapeExpr) }; } private IRType Visit(ITypeInferenceContext context, Expand target, DistributedType input, TensorType shape) diff --git a/src/Nncase.Evaluator/Tensors/Flatten.cs b/src/Nncase.Evaluator/Tensors/Flatten.cs index 23db9ffa3..c3eb97a62 100644 --- a/src/Nncase.Evaluator/Tensors/Flatten.cs +++ b/src/Nncase.Evaluator/Tensors/Flatten.cs @@ -44,15 +44,13 @@ private IRType Visit(ITypeInferenceContext context, Flatten target, TensorType i { if (context.GetArgument(target, Flatten.Axis) is TensorConst axisV) { - if (input.Shape.IsFixed) - { - var axisValue = Util.PositiveIndex(axisV.Value.ToScalar(), input); - var first = input.Shape.Take(axisValue).Aggregate(1, (x, y) => x * y.FixedValue); - var second = 
input.Shape.Take(axisValue..input.Shape.Count).Aggregate(1, (x, y) => x * y.FixedValue); - return input with { Shape = new[] { first, second } }; - } + var axisValue = Util.PositiveIndex(axisV.Value.ToScalar(), input); + var first = input.Shape.Take(axisValue).Aggregate((Dimension)1, (x, y) => x * y); + var second = input.Shape.Take(axisValue..input.Shape.Count).Aggregate((Dimension)1, (x, y) => x * y); + return input with { Shape = new[] { first, second } }; } - return input with { Shape = Shape.Unknown(2) }; + // return input with { Shape = Shape.Unknown(2) }; + throw new NotImplementedException(); } } diff --git a/src/Nncase.Evaluator/Tensors/GatherND.cs b/src/Nncase.Evaluator/Tensors/GatherND.cs index 4a8f25a5a..e3653cde4 100644 --- a/src/Nncase.Evaluator/Tensors/GatherND.cs +++ b/src/Nncase.Evaluator/Tensors/GatherND.cs @@ -73,7 +73,7 @@ private IRType Visit(ITypeInferenceContext context, GatherND target, TensorType // result shape = index_shape[:-1] + input_shape[index_shape[-1] + batch_dims:] var dimensions = index.Shape.ToArray()[..(index.Shape.Rank - 1)]; - var d = lastIndexDims.FixedValue + batchDimsValue.Value.ToScalar(); + var d = (int)lastIndexDims.FixedValue + batchDimsValue.Value.ToScalar(); var shapeValue = dimensions.Concat(input.Shape.ToArray()[d..]); return new TensorType(input.DType, new IR.Shape(shapeValue)); } diff --git a/src/Nncase.Evaluator/Tensors/GetItem.cs b/src/Nncase.Evaluator/Tensors/GetItem.cs index 4c3d3a76d..b6ffd8603 100644 --- a/src/Nncase.Evaluator/Tensors/GetItem.cs +++ b/src/Nncase.Evaluator/Tensors/GetItem.cs @@ -73,6 +73,21 @@ public Expr Visit(IShapeEvaluateContext context, GetItem target) } } + public IRType Visit(ITypeInferenceContext context, GetItem target) + { + var input = context.CheckArgumentType(target, GetItem.Input); + var index = context.CheckArgumentType(target, GetItem.Index); + + return input switch + { + TensorType t => Visit(context, target, t, index), + DistributedType d => Visit(context, target, d, index), + TupleType t => Visit(context, target, t, index), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().Name), + }; + } + private IValue Visit(IValue input, IValue index) { if (input.Type is TensorType ttype) @@ -85,88 +100,99 @@ private IValue Visit(IValue input, IValue index) var indicesValue = indices.Select((x, i) => x < 0 ? 
x + tensor.Shape[i].FixedValue : x).ToArray(); var linearIndex = TensorUtilities.GetIndex(tensor.Strides, indicesValue); - var returnDims = tensor.Dimensions.AsValueEnumerable().Skip(indexTensor.Length).ToArray(); - var elementsCount = (int)TensorUtilities.GetProduct(returnDims); + var returnDims = tensor.Dimensions.AsValueEnumerable().Skip((int)indexTensor.Length).ToArray(); + var elementsCount = TensorUtilities.GetProduct(returnDims); - var src = tensor.BytesBuffer.Slice(elementSize * linearIndex, elementSize * elementsCount); + var src = tensor.BytesBuffer.Slice(checked((int)(elementSize * linearIndex)), checked((int)(elementSize * elementsCount))); return Value.FromTensor(Tensor.FromBytes(new TensorType(ttype.DType, returnDims), src.ToArray())); } return input[index.AsTensor().ToScalar()]; } - private IRType Visit(ITypeInferenceContext context, GetItem target, IRType input, TensorType index) + private IRType Visit(ITypeInferenceContext context, GetItem target, TensorType input, TensorType index) { - IRType ret = new InvalidType("Need Be Reset!"); var indexExpr = context.GetArgument(target, GetItem.Index); - switch (input) + if (input.Shape.IsUnranked) { - case TensorType tensorType: - if (tensorType.Shape.IsUnranked) - { - return input; - } + return input; + } - if (indexExpr is TensorConst indexV) - { - var indices = indexV.Value.ToArray(); - if (indices.Length > tensorType.Shape.Rank) - { - return new InvalidType("GetItem index count should smaller than in shape rank"); - } + if (indexExpr is TensorConst indexV) + { + var indices = indexV.Value.ToArray(); + if (indices.Length > input.Shape.Rank) + { + return new InvalidType("GetItem index count should smaller than in shape rank"); + } - if (indices.Length == tensorType.Shape.Rank) + if (indices.Length == input.Shape.Rank) + { + foreach (var (i, dim) in indices.Zip(input.Shape)) + { + if (dim.IsFixed && i >= dim.FixedValue) { - foreach (var (i, dim) in indices.Zip(tensorType.Shape)) - { - if (dim.IsFixed && i >= dim.FixedValue) - { - return new InvalidType("GetItem index value shoud smaller than shape dim"); - } - } + return new InvalidType("GetItem index value shoud smaller than shape dim"); } } + } + } - var shape = index.Shape switch - { - { IsScalar: true } => new Shape(tensorType.Shape.Skip(1)), - { IsFixed: true } => index.Shape[0].FixedValue == tensorType.Shape.Rank ? - Shape.Scalar : - new Shape(tensorType.Shape.Skip(index.Shape[0].FixedValue)), - _ => Shape.Unranked, - }; - ret = new TensorType(tensorType.DType, shape); - break; - case TupleType tupleType: - if (indexExpr is TensorConst @const) + var shape = index.Shape switch + { + { IsScalar: true } => new Shape(input.Shape.Skip(1)), + { IsFixed: true } => index.Shape[0].FixedValue == input.Shape.Rank ? 
+ Shape.Scalar : + new Shape(input.Shape.Skip((int)index.Shape[0].FixedValue)), + _ => Shape.Unranked, + }; + return new TensorType(input.DType, shape); + } + + private IRType Visit(ITypeInferenceContext context, GetItem target, TupleType input, TensorType index) + { + var indexExpr = context.GetArgument(target, GetItem.Index); + if (indexExpr is TensorConst @const) + { + var indexValue = @const.Value.ToScalar(); + if (indexValue < input.Count) + { + return input[indexValue]; + } + else + { + if (input.IsVariadic) { - var indexValue = @const.Value.ToScalar(); - if (indexValue < tupleType.Count) - { - ret = tupleType[indexValue]; - } - else - { - if (tupleType.IsVariadic) - { - ret = tupleType[0]; - } - else - { - ret = new InvalidType($"The Input Tuple Count = {tupleType.Count}, But Index = {indexValue}"); - } - } + return input[0]; } else { - ret = AnyType.Default; + return new InvalidType($"The Input Tuple Count = {input.Count}, But Index = {indexValue}"); } + } + } + else + { + return AnyType.Default; + } + } - break; - default: - break; + private IRType Visit(ITypeInferenceContext context, GetItem target, DistributedType input, TensorType index) + { + var outputType = (TensorType)Visit(context, target, input.TensorType, index); + var ndsbp = input.NdSBP.ToArray(); + for (var i = 0; i < ndsbp.Length; i++) + { + if (ndsbp[i] is SBPSplit { Axis: int axis }) + { + if ((index.Shape.IsScalar && axis == 0) + || axis < index.Shape[0].FixedValue) + { + ndsbp[i] = SBP.B; + } + } } - return ret; + return new DistributedType(outputType, ndsbp, input.Placement); } } diff --git a/src/Nncase.Evaluator/Tensors/Range.cs b/src/Nncase.Evaluator/Tensors/Range.cs index dac1afb1d..25cf1c6e6 100644 --- a/src/Nncase.Evaluator/Tensors/Range.cs +++ b/src/Nncase.Evaluator/Tensors/Range.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System; using Nncase.CostModel; using Nncase.IR; using Nncase.Utilities; @@ -65,7 +66,8 @@ public IRType Visit(ITypeInferenceContext context, Range target) return new InvalidType("DataType is unknown"); } - return new TensorType(dt, new Shape(Dimension.Unknown)); + // return new TensorType(dt, new Shape(Dimension.Unknown)); + throw new NotImplementedException(); } } diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index ab211fdca..8b7e3b396 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -91,11 +91,12 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i return input; } - if (context.GetArgument(target, Reshape.Shape) is TensorConst shapeConst) + var shape = context.GetDimensionArgument(target, Reshape.Shape); + if (shape is TensorConst shapeConst) { - var shapeValue = shapeConst.Value.ToArray(); + var shapeValue = shapeConst.Value.ToArray(); var negCount = shapeValue.Count(IsMinus1); - var shapeSize = shapeValue.Aggregate(1, (x, y) => x * y); + var shapeSize = shapeValue.Aggregate(1L, (x, y) => x * y); if (negCount > 1) { return new InvalidType( @@ -133,14 +134,18 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i { return input with { - Shape = new Shape(shapeValue.Select(x => x == -1 ? Dimension.Unknown : x).ToArray()), + Shape = new Shape(shapeValue.Select(x => x == -1 ? 
Dimension.Unknown() : x).ToArray()), }; } } - var targetType = context.CheckArgumentType(target, Reshape.Shape); - var outShape = ReshapeTo(targetType); - return input with { Shape = outShape }; + if (!shape.CheckedShape.IsRanked || !shape.CheckedShape[0].IsFixed) + { + return input with { Shape = Shape.Unranked }; + } + + var newShape = Shape.Unknown((int)shape.CheckedShape[0].FixedValue); + return input with { Shape = newShape }; } private IRType Visit(ITypeInferenceContext context, Reshape target, DistributedType inputType) @@ -157,24 +162,27 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, DistributedT return invalid; } - var newShape = outTensorType.Shape.ToValueArray(); - var oldShape = inputType.TensorType.Shape.ToValueArray(); + var newShape = outTensorType.Shape; + var oldShape = inputType.TensorType.Shape; // check is unsequeeze/sequeeze if (Enumerable.SequenceEqual(oldShape.Where(i => i != 1).ToArray(), newShape.Where(i => i != 1).ToArray())) { - if (oldShape.Length < newShape.Length) + if (oldShape.Count < newShape.Count) { var axis = 0; var axisMap = new Dictionary(); - for (var n = 0; n < newShape.Length; n++) + if (!oldShape.IsScalar) { - if (newShape[n] == oldShape[axis]) + for (var n = 0; n < newShape.Count; n++) { - axisMap.Add(axis++, n); - if (axis >= oldShape.Length) + if (newShape[n] == oldShape[axis]) { - break; + axisMap.Add(axis++, n); + if (axis >= oldShape.Count) + { + break; + } } } } @@ -191,16 +199,16 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, DistributedT return inputType with { TensorType = outTensorType, NdSBP = new(ndsbp) }; } - else if (oldShape.Length > newShape.Length) + else if (oldShape.Count > newShape.Count) { var axis = 0; var axisMap = new Dictionary(); - for (var o = 0; o < oldShape.Length; o++) + for (var o = 0; o < oldShape.Count; o++) { - if (axis < newShape.Length && oldShape[o] == newShape[axis]) + if (axis < newShape.Count && oldShape[o] == newShape[axis]) { axisMap.Add(o, axis++); - if (axis >= newShape.Length) + if (axis >= newShape.Count) { break; } diff --git a/src/Nncase.Evaluator/Tensors/ScatterND.cs b/src/Nncase.Evaluator/Tensors/ScatterND.cs index b8e80e107..5bde446a8 100644 --- a/src/Nncase.Evaluator/Tensors/ScatterND.cs +++ b/src/Nncase.Evaluator/Tensors/ScatterND.cs @@ -29,7 +29,7 @@ public IValue Visit(IEvaluateContext context, ScatterND target) var input = context.GetArgumentValueAsTensor(target, ScatterND.Input); var indices = context.GetArgumentValueAsTensor(target, ScatterND.Indices); var updates = context.GetArgumentValueAsTensor(target, ScatterND.Updates); - var update_indices = indices.Shape.ToValueArray().Take(0..(indices.Shape.Rank - 1)).Select(i => Enumerable.Range(0, i)); + var update_indices = indices.Shape.ToValueArray().Take(0..(indices.Shape.Rank - 1)).Select(i => Enumerable.Range(0, (int)i).Select(x => (long)x)); var output = Tensor.FromBytes(input.ElementType, input.BytesBuffer.ToArray(), input.Shape); var indicesSpan = indices.Buffer.Span; var updatesSpan = updates.BytesBuffer; @@ -41,9 +41,9 @@ public IValue Visit(IEvaluateContext context, ScatterND target) var outputSpanStride = output.Strides.ToArray().SkipLast(updatesRemain.Count()).Select(s => s * input.ElementType.SizeInBytes).ToArray(); foreach (var idx in LinqExtensions.CartesianProduct(update_indices)) { - var index = indicesSpan.Slice(TensorUtilities.GetIndex(indicesSpanStride, idx.ToArray()), indices.Shape.ToValueArray()[^1]); - var updatesSlice = 
updatesSpan.Slice(TensorUtilities.GetIndex(updatesSliceStride, idx.ToArray()), updateSize); - updatesSlice.CopyTo(outputSpan.Slice(TensorUtilities.GetIndex(outputSpanStride, index.ToArray()))); + var index = indicesSpan.Slice(checked((int)TensorUtilities.GetIndex(indicesSpanStride, idx.ToArray())), checked((int)indices.Shape.ToValueArray()[^1])); + var updatesSlice = updatesSpan.Slice(checked((int)TensorUtilities.GetIndex(updatesSliceStride, idx.ToArray())), checked((int)updateSize)); + updatesSlice.CopyTo(outputSpan.Slice(checked((int)TensorUtilities.GetIndex(outputSpanStride, index.ToArray().ToLongs())))); } return Value.FromTensor(output); diff --git a/src/Nncase.Evaluator/Tensors/ShapeOf.cs b/src/Nncase.Evaluator/Tensors/ShapeOf.cs index 3b3a2c197..43658f541 100644 --- a/src/Nncase.Evaluator/Tensors/ShapeOf.cs +++ b/src/Nncase.Evaluator/Tensors/ShapeOf.cs @@ -27,14 +27,20 @@ public IValue Visit(IEvaluateContext context, ShapeOf shape) /// public IRType Visit(ITypeInferenceContext context, ShapeOf target) { - var input = context.CheckArgumentType(target, ShapeOf.Input); - return Visit(context, target, input); + var input = context.CheckArgumentType(target, ShapeOf.Input); + return input switch + { + TensorType t => Visit(context, target, t), + DistributedType d => Visit(context, target, d), + AnyType => AnyType.Default, + _ => new InvalidType(input.GetType().Name), + }; } /// public Cost Visit(ICostEvaluateContext context, ShapeOf target) { - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { @@ -49,7 +55,7 @@ public Expr Visit(IShapeEvaluateContext context, ShapeOf target) public Metric Visit(IMetricEvaluateContext context, ShapeOf target) { - var outputType = context.GetReturnType(); + var outputType = context.GetReturnType(); return new() { @@ -65,6 +71,18 @@ private IRType Visit(ITypeInferenceContext context, ShapeOf target, TensorType i return new TensorType(DataTypes.Int64, new Shape(input.Shape.Rank)); } - return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown)); + return new TensorType(DataTypes.Int64, Shape.Unknown(1)); + } + + private IRType Visit(ITypeInferenceContext context, ShapeOf target, DistributedType input) + { + var outType = Visit(context, target, input.TensorType); + if (outType is not TensorType tensorType) + { + return new InvalidType("not support input tensor type infer"); + } + + var ndsbp = Enumerable.Repeat(SBP.B, input.Placement.Rank).ToArray(); + return new DistributedType(tensorType, ndsbp, input.Placement); } } diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index 1e44f7808..3c945399e 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -43,10 +43,10 @@ public IValue Visit(IEvaluateContext context, Slice sl) public IRType Visit(ITypeInferenceContext context, Slice target) { var input = context.CheckArgumentType(target, Slice.Input); - context.CheckArgumentType(target, Slice.Begins); - context.CheckArgumentType(target, Slice.Ends); - context.CheckArgumentType(target, Slice.Axes); - context.CheckArgumentType(target, Slice.Strides); + context.CheckArgumentTensorTypeOrBroadcast(target, Slice.Begins); + context.CheckArgumentTensorTypeOrBroadcast(target, Slice.Ends); + context.CheckArgumentTensorTypeOrBroadcast(target, Slice.Axes); + context.CheckArgumentTensorTypeOrBroadcast(target, Slice.Strides); return input switch { TensorType t => Visit(context, target, t), @@ -117,7 +117,7 @@ public Expr 
Visit(IShapeEvaluateContext context, Slice target) /// Axis. /// Input type. /// (index in axis, axis, inDim) -> outDim. - private Shape ApplyAxis(TensorConst axisConst, TensorType input, Func f) + private Shape ApplyAxis(TensorConst axisConst, TensorType input, Func f) { if (input.Shape.IsUnranked) { @@ -132,9 +132,7 @@ private Shape ApplyAxis(TensorConst axisConst, TensorType input, Func Dimension.Unknown); + // outShape = ApplyAxis(axes_con, input, (i, axis, inDim) => Dimension.Unknown); + outShape = Shape.Unknown(input.Shape.Rank); } } else diff --git a/src/Nncase.Evaluator/Tensors/Split.cs b/src/Nncase.Evaluator/Tensors/Split.cs index a6c565c87..60179cd2e 100644 --- a/src/Nncase.Evaluator/Tensors/Split.cs +++ b/src/Nncase.Evaluator/Tensors/Split.cs @@ -74,7 +74,7 @@ private IRType Visit(ITypeInferenceContext context, Split target, TensorType inp if (input.Shape.IsUnranked) { - return new TupleType(Enumerable.Repeat((IRType)(input with { Shape = Shape.Unranked }), sections_v.Length)); + return new TupleType(Enumerable.Repeat((IRType)(input with { Shape = Shape.Unranked }), (int)sections_v.Length)); } var inshape = input.Shape.ToArray(); @@ -111,13 +111,14 @@ private IRType Visit(ITypeInferenceContext context, Split target, TensorType inp if (context.GetArgument(target, Split.Axis) is TensorConst axisCon) { var axisV = Util.PositiveIndex(axisCon.Value.ToScalar(), input.Shape.Rank); - splitedShape[axisV] = Dimension.Unknown; + splitedShape[axisV] = Dimension.Unknown(); } else { - splitedShape = splitedShape.Select(s => Dimension.Unknown).ToArray(); + splitedShape = splitedShape.Select(s => Dimension.Unknown()).ToArray(); } - return new TupleType(new IRType[] { input with { Shape = splitedShape } }, true); + // return new TupleType(new IRType[] { input with { Shape = splitedShape } }, true); + throw new NotImplementedException(); } } diff --git a/src/Nncase.Evaluator/Tensors/Squeeze.cs b/src/Nncase.Evaluator/Tensors/Squeeze.cs index 19573e720..9cc1dd585 100644 --- a/src/Nncase.Evaluator/Tensors/Squeeze.cs +++ b/src/Nncase.Evaluator/Tensors/Squeeze.cs @@ -73,6 +73,7 @@ private IRType Visit(ITypeInferenceContext context, Squeeze target, TensorType i return input with { Shape = new Shape(outshape.Where(x => x != int.MaxValue)) }; } - return input with { Shape = new Shape(Enumerable.Repeat(Dimension.Unknown, input.Shape.Count - 1)) }; + // return input with { Shape = new Shape(Enumerable.Repeat(Dimension.Unknown, input.Shape.Count - 1)) }; + throw new NotImplementedException(); } } diff --git a/src/Nncase.Evaluator/Tensors/Tile.cs b/src/Nncase.Evaluator/Tensors/Tile.cs index 19bebffdf..a12469331 100644 --- a/src/Nncase.Evaluator/Tensors/Tile.cs +++ b/src/Nncase.Evaluator/Tensors/Tile.cs @@ -61,19 +61,18 @@ public Metric Visit(IMetricEvaluateContext context, Tile target) private IRType Visit(ITypeInferenceContext context, Tile target, TensorType input, TensorType repeat) { - if (input.Shape.IsUnranked) + var inShape = input.Shape; + var repeats = context.GetArgument(target, Tile.Repeats); + if (repeats is TensorConst tc) { - return input; + var repeatsValue = tc.Value.ToArray(); + var shape = input.Shape.Zip(repeatsValue).Select(p => p.First * p.Second); + return input with { Shape = new Shape(shape) }; } - - if (context.GetArgument(target, Tile.Repeats) is TensorConst repeats && input.Shape.IsFixed) + else { - var shape = input.Shape.ToValueArray().Zip(repeats.Value.ToArray()).Select(p => p.First * p.Second); - return input with { Shape = new Shape(shape.ToArray()) }; + var shape = 
input.Shape.Select((p, i) => p * repeats[i]); + return input with { Shape = new Shape(shape) }; } - - return new TensorType( - input.DType, - new Shape(Enumerable.Repeat(Dimension.Unknown, input.Shape.Rank))); } } diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs index e8312c491..76e1c7ad5 100644 --- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs +++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs @@ -28,7 +28,6 @@ public IValue Visit(IEvaluateContext context, Unsqueeze unSqueeze) public IRType Visit(ITypeInferenceContext context, Unsqueeze target) { var input = context.CheckArgumentType(target, Unsqueeze.Input); - _ = context.CheckArgumentType(target, Unsqueeze.Dim); if (input is TensorType tensorType) { return Visit(context, target, tensorType); @@ -66,7 +65,7 @@ private IRType Visit(ITypeInferenceContext context, Unsqueeze target, TensorType return input; } - if (context.GetArgument(target, Unsqueeze.Dim) is TensorConst axes) + if (context.GetDimensionArgument(target, Unsqueeze.Dim) is TensorConst axes) { var axesValue = axes.Value.ToArray(); var outShape = new Dimension[input.Shape.Rank + axesValue.Length]; @@ -87,7 +86,7 @@ private IRType Visit(ITypeInferenceContext context, Unsqueeze target, TensorType return input with { Shape = new Shape(outShape) }; } - return input with { Shape = new Shape(Enumerable.Repeat(Dimension.Unknown, input.Shape.Rank + 1)) }; + return input with { Shape = Shape.Unknown(input.Shape.Rank + 1) }; } private IRType Visit(ITypeInferenceContext context, Unsqueeze target, DistributedType input) diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs index f45d41b73..d80c9fcd1 100644 --- a/src/Nncase.Evaluator/Tensors/Where.cs +++ b/src/Nncase.Evaluator/Tensors/Where.cs @@ -62,7 +62,7 @@ public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target) { if (target.IsTfWhere) { - return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank)); + return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown(), cond.Shape.Rank)); } return TypeInference.BroadcastType(x.DType, cond, x, y); diff --git a/src/Nncase.Evaluator/TypeInference.cs b/src/Nncase.Evaluator/TypeInference.cs index ce1fda29d..58c554175 100644 --- a/src/Nncase.Evaluator/TypeInference.cs +++ b/src/Nncase.Evaluator/TypeInference.cs @@ -56,6 +56,27 @@ T WrapperException(T t) }; } + public static TensorType CheckArgumentTensorTypeOrBroadcast(this ITypeInferenceContext context, Op op, ParameterInfo parameter, string? reason = null) + { + TensorType WrapperException(TensorType t) + { + try + { + return parameter.Pattern.Check(t, $"{op.GetType().Name}.{parameter.Name}"); + } + catch (System.InvalidOperationException e) + { + throw new TypeInferenceInterruptException(new InvalidType(e.Message)); + } + } + + return context.GetArgumentType(op, parameter) switch + { + DistributedType d when d.NdSBP.All(x => x is SBPBroadCast) => WrapperException(d.TensorType), + IRType t => CheckArgumentType(context, op, parameter, reason), + }; + } + /// /// Throw if type is or . /// @@ -121,7 +142,7 @@ public static IRType BroadcastType(DataType dataType, params TensorType[] inputs var inExtend = outputRank - inShape.Rank; var inDimIndex = dimIndex - inExtend; var inDim = inDimIndex < 0 ? 
1 : inShape[inDimIndex]; - if (inDim is Dimension { Value: 0 }) + if (inDim is Dimension { IsFixed: true, FixedValue: 0 }) { return new InvalidType("Input dimension should not be 0."); } @@ -129,30 +150,34 @@ public static IRType BroadcastType(DataType dataType, params TensorType[] inputs inputDims[i] = inDim; } - if (inputDims.All(x => x.IsFixed)) + var non1Dims = inputDims.Where(x => x.IsUnknown || x.FixedValue != 1).ToArray(); + if (non1Dims.Length == 0) { - // 1. Sort descending - Array.Sort(inputDims, (a, b) => b.FixedValue.CompareTo(a.FixedValue)); - - // 2. Find first 1 - var firstOneIndex = inputDims.IndexOf(1); - var expectedDim = inputDims[0]; - - // 3. Dims before 1 are all same or 1 is not found, it's ok to broadcast - if ((firstOneIndex == -1 && inputDims.AsValueEnumerable().Distinct().Count() == 1) || - ((firstOneIndex != -1) && inputDims[..firstOneIndex].AsValueEnumerable().All(x => x == expectedDim))) + outputShape[dimIndex] = 1; + } + else + { + var expectedDim = non1Dims[0]; + if (non1Dims.Length == 1) { outputShape[dimIndex] = expectedDim; } + else if (non1Dims.All(x => x.IsFixed)) + { + if (non1Dims.All(x => x == expectedDim)) + { + outputShape[dimIndex] = expectedDim; + } + else + { + return new InvalidType("Inputs are not compatible to broadcast."); + } + } else { - return new InvalidType("Inputs are not compatible to broadcast."); + outputShape[dimIndex] = Dimension.Unknown(); // IR.F.Math.Max(non1Dims.Select(x => x.Value)); } } - else - { - outputShape[dimIndex] = Dimension.Unknown; - } } return new TensorType(dataType, new Shape(outputShape)); @@ -170,8 +195,7 @@ public static IRType Conv2DType(TensorType input, TensorType weights, Expr strid var outShape = input.Shape.ToList(); outShape[1] = weights.Shape[0]; - if ( - stride is TensorConst strideValue && + if (stride is TensorConst strideValue && padding is TensorConst paddingValue && dilation is TensorConst dilation_con && groups is TensorConst groups_con && @@ -193,21 +217,22 @@ groups is TensorConst groups_con && } outShape[2] = GetWindowedOutputSize( - input.Shape[2].FixedValue + ts_padding[0, 0] + ts_padding[0, 1], - weights.Shape[2].FixedValue, + (int)input.Shape[2].FixedValue + ts_padding[0, 0] + ts_padding[0, 1], + (int)weights.Shape[2].FixedValue, ts_stride[0], ts_dilation[0], false); outShape[3] = GetWindowedOutputSize( - input.Shape[3].FixedValue + ts_padding[1, 0] + ts_padding[1, 1], - weights.Shape[3].FixedValue, + (int)input.Shape[3].FixedValue + ts_padding[1, 0] + ts_padding[1, 1], + (int)weights.Shape[3].FixedValue, ts_stride[1], ts_dilation[1], false); } else { - outShape[2] = outShape[3] = Dimension.Unknown; + // outShape[2] = outShape[3] = Dimension.Unknown; + throw new NotImplementedException(); } return input with { Shape = new Shape(outShape) }; @@ -236,7 +261,7 @@ public static IRType PadType(TensorType input, Expr pads, Expr pad) { var tpads = paddings.Value.Cast(); var newShape = input.Shape.ToList(); - int channel = tpads.Dimensions[0]; + int channel = (int)tpads.Dimensions[0]; for (int i = 0; i < channel; i++) { newShape[newShape.Count - channel + i] += tpads[i, 0] + tpads[i, 1]; @@ -274,11 +299,11 @@ padding is TensorConst paddingValue && var padh = ts_padding[0, 0] + ts_padding[0, 1]; var padw = ts_padding[1, 0] + ts_padding[1, 1]; outShape[2] = input.Shape[2].IsUnknown - ? Dimension.Unknown - : GetWindowedOutputSize(input.Shape[2].FixedValue + padh, ts_filter[0], ts_stride[0], 1, false, ceilModeV); + ? 
throw new NotImplementedException() + : GetWindowedOutputSize((int)input.Shape[2].FixedValue + padh, ts_filter[0], ts_stride[0], 1, false, ceilModeV); outShape[3] = input.Shape[3].IsUnknown - ? Dimension.Unknown - : GetWindowedOutputSize(input.Shape[3].FixedValue + padw, ts_filter[1], ts_stride[1], 1, false, ceilModeV); + ? throw new NotImplementedException() + : GetWindowedOutputSize((int)input.Shape[3].FixedValue + padw, ts_filter[1], ts_stride[1], 1, false, ceilModeV); return input with { Shape = new Shape(outShape) }; } @@ -429,6 +454,7 @@ public static IRType ResizeType(TensorType input, Expr newSize, TensorType? inpu } else { +#if false switch (out_shape.Length) { case 2 or 3: @@ -440,6 +466,8 @@ public static IRType ResizeType(TensorType input, Expr newSize, TensorType? inpu out_shape[^1] = Dimension.Unknown; break; } +#endif + throw new NotImplementedException(); } // for roi amount. @@ -454,21 +482,7 @@ public static IRType ResizeType(TensorType input, Expr newSize, TensorType? inpu /// /// input x is -1?. /// - public static bool IsMinus1(int x) => x == -1; - - public static Shape ReshapeTo(TensorType tensorType) - { - var shape = tensorType.Shape; - if (shape.IsRanked && shape[0].IsFixed) - { - Trace.Assert(shape.Count != 0); - return Shape.Unknown(shape[0].FixedValue); - } - else - { - return Shape.Unranked; - } - } + public static bool IsMinus1(long x) => x == -1; /// /// Infer CommonType for inputs. diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index ff2a00a75..745fe0652 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -83,12 +83,7 @@ protected override IRType VisitLeafBuffer(Nncase.TIR.Buffer expr) VerifySubField(expr, r, TypePatternUtility.IsIntegralScalar()); } - var type = new TensorType(expr.ElemType, expr.Dimensions.AsValueEnumerable().Select(e => e switch - { - TensorConst { Value: { Shape: { IsScalar: true } } t } => new Dimension(t.ToScalar()), - _ => Dimension.Unknown, - }).ToArray()); - + var type = new TensorType(expr.ElemType, new Shape(expr.Dimensions)); return type; } diff --git a/src/Nncase.Importer/Ncnn/Convolution.cs b/src/Nncase.Importer/Ncnn/Convolution.cs index 5885feef4..3b978643d 100644 --- a/src/Nncase.Importer/Ncnn/Convolution.cs +++ b/src/Nncase.Importer/Ncnn/Convolution.cs @@ -78,10 +78,10 @@ private Expr VisitConvolution(NcnnLayer layer) paddingW = new Expr[] { padLeft, padRight }; } - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.From(new[] { dilationH, dilationW }, [2]); var clampRange = ToFloatClampRange(activationType, activationParams); - var clamp = Tensor.From(new[] { clampRange.Min, clampRange.Max }, new[] { 2 }); + var clamp = Tensor.From(new[] { clampRange.Min, clampRange.Max }, [2]); var padding = Util.ConcatPadding(paddingH, paddingW); var weights = _modelBin.LoadFloat32(new[] { numOutput, numInput, kernelH, kernelW }, true); var bias = biasTerm != 0 ? 
_modelBin.LoadFloat32(new[] { numOutput }, false) : Tensor.FromScalar(0f, numOutput); diff --git a/src/Nncase.Importer/Ncnn/NcnnModelBin.cs b/src/Nncase.Importer/Ncnn/NcnnModelBin.cs index 3957e1fbb..9d3dd7a55 100644 --- a/src/Nncase.Importer/Ncnn/NcnnModelBin.cs +++ b/src/Nncase.Importer/Ncnn/NcnnModelBin.cs @@ -26,7 +26,7 @@ public Tensor LoadFloat32(ReadOnlySpan shape, bool detectType) { if (!detectType) { - var tensor = new Tensor(shape); + var tensor = new Tensor(shape.ToLongs()); _stream.ReadExactly(tensor.BytesBuffer); return tensor; } @@ -45,7 +45,7 @@ public Tensor LoadAuto(ReadOnlySpan shape) if (tag == 0x01306B47) { // half-precision data - var tensor = new Tensor(shape); + var tensor = new Tensor(shape.ToLongs()); _stream.ReadExactly(tensor.BytesBuffer); AlignStream(_stream, tensor.BytesBuffer.Length, 4); return tensor.Cast(CastMode.KDefault); @@ -53,7 +53,7 @@ public Tensor LoadAuto(ReadOnlySpan shape) else if (tag == 0) { // raw data - var tensor = new Tensor(shape); + var tensor = new Tensor(shape.ToLongs()); _stream.ReadExactly(tensor.BytesBuffer); return tensor; } diff --git a/src/Nncase.Importer/Ncnn/Pooling.cs b/src/Nncase.Importer/Ncnn/Pooling.cs index edc486d71..1d8236c6e 100644 --- a/src/Nncase.Importer/Ncnn/Pooling.cs +++ b/src/Nncase.Importer/Ncnn/Pooling.cs @@ -42,9 +42,9 @@ private Expr VisitPooling(NcnnLayer layer) 1 => (ReduceOp.Mean, 0f), _ => throw new NotSupportedException($"Unsupported pooling type: {poolingType}."), }; - var filter = Tensor.From(new[] { kernelH, kernelW }, new[] { 2 }); - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.FromScalar(0, new[] { 2, 2 }); + var filter = Tensor.From(new[] { kernelH, kernelW }, [2]); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.FromScalar(0, [2, 2]); if (globalPooling) { @@ -52,7 +52,7 @@ private Expr VisitPooling(NcnnLayer layer) } else if (adaptivePooling) { - var padding = Tensor.FromScalar(0, new[] { 2, 2 }); + var padding = Tensor.FromScalar(0, [2, 2]); var inShape = Tensors.ShapeOf(input); var w = inShape[3]; var h = inShape[2]; diff --git a/src/Nncase.Importer/Onnx/DataGatter.cs b/src/Nncase.Importer/Onnx/DataGatter.cs index de44804a5..65dc7a622 100644 --- a/src/Nncase.Importer/Onnx/DataGatter.cs +++ b/src/Nncase.Importer/Onnx/DataGatter.cs @@ -34,7 +34,7 @@ public sealed partial class OnnxImporter public Shape GetShape(ValueInfoProto v) { var shape = v.Type.TensorType.Shape.Dim; - var dimArr = GetDimArray(shape, d => d, _ => Dimension.Unknown, d => (Dimension)d.DimValue); + var dimArr = GetDimArray(shape, d => d, d => Dimension.Unknown(d.DimParam), d => (Dimension)d.DimValue); return new Shape(dimArr); } diff --git a/src/Nncase.Importer/Onnx/MatMul.cs b/src/Nncase.Importer/Onnx/MatMul.cs index 5c45cb855..d02d74e28 100644 --- a/src/Nncase.Importer/Onnx/MatMul.cs +++ b/src/Nncase.Importer/Onnx/MatMul.cs @@ -14,6 +14,7 @@ public partial class OnnxImporter private Expr VisitMatMul(in NodeProto op) { var (a, b) = GetInputExprs(op, 0, 1); + // /mlp_2/Mul_output_0、/mlp_3/Mul_output_0、/mlp_21/Mul_output_0 if (a.Metadata.OutputNames![0] == "/mlp_2/Mul_output_0") { @@ -82,4 +83,4 @@ private Expr VisitMatMul(in NodeProto op) } } } -} \ No newline at end of file +} diff --git a/src/Nncase.Importer/Onnx/Pad.cs b/src/Nncase.Importer/Onnx/Pad.cs index 9726f684b..87560aa85 100644 --- a/src/Nncase.Importer/Onnx/Pad.cs +++ b/src/Nncase.Importer/Onnx/Pad.cs @@ -23,7 +23,7 @@ private Expr PadV2(in NodeProto op) var input = GetInputExpr(op, 0); 
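+ // Descriptive note: the pads tensor built a few lines below now passes its shape as a
+ // C# collection expression ([2, 4]) instead of new[] { 2, 4 }, matching the other
+ // Tensor.From / Tensor.FromScalar call sites touched by this patch (presumably so the
+ // literals bind cleanly to the long-based shape overloads introduced by the refactor).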
var padMode = GetPadMode(op); var paddings = GetIntsAttribute(op, "pads"); - var pads = Tensor.From(paddings, new[] { 2, 4 }); + var pads = Tensor.From(paddings, [2, 4]); var value = GetFloatAttribute(op, "value", 0f); return Pad(input, ToNncasePadFormat(pads), padMode, value); } diff --git a/src/Nncase.Importer/Onnx/QLinearConv.cs b/src/Nncase.Importer/Onnx/QLinearConv.cs index 3e1bb7e07..6d2709f34 100644 --- a/src/Nncase.Importer/Onnx/QLinearConv.cs +++ b/src/Nncase.Importer/Onnx/QLinearConv.cs @@ -29,20 +29,20 @@ private Expr VisitQLinearConv(in NodeProto op) var group = GetIntAttribute(op, "group", 1); var strides = GetStrideAttribute(op); - int? stridesValueLen = ((TensorConst)strides).CheckedShape[0].Value; + int stridesValueLen = (int)((TensorConst)strides).CheckedShape[0].FixedValue; for (var i = 0; i < stridesValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)strides).Value.Cast()[i] <= (long)int.MaxValue); } - int? dilationValueLen = ((TensorConst)dilation).CheckedShape[0].Value; + int dilationValueLen = (int)((TensorConst)dilation).CheckedShape[0].FixedValue; for (var i = 0; i < dilationValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)dilation).Value.Cast()[i] <= (long)int.MaxValue); } var pads = AutoPad(op, autoPad, input, weights, strides.ToArray(), dilation); - int[] strideArr = new int[stridesValueLen == null ? default : stridesValueLen.Value]; + int[] strideArr = new int[stridesValueLen]; for (var i = 0; i < stridesValueLen; i++) { strideArr[i] = ((TensorConst)strides).Value.Cast()[i]; @@ -50,7 +50,7 @@ private Expr VisitQLinearConv(in NodeProto op) var strideConst = new TensorConst(Tensor.From(strideArr)); - int[] dilationArr = new int[dilationValueLen == null ? default : dilationValueLen.Value]; + int[] dilationArr = new int[dilationValueLen]; for (var i = 0; i < dilationValueLen; i++) { dilationArr[i] = ((TensorConst)dilation).Value.Cast()[i]; @@ -63,7 +63,7 @@ private Expr VisitQLinearConv(in NodeProto op) if (bias == null) { - int? ocNumber = ((TensorConst)weights).CheckedShape[0].Value; + int? ocNumber = (int)((TensorConst)weights).CheckedShape[0].FixedValue; var zeroBias = new TensorConst(new int[ocNumber == null ? 
default(int) : ocNumber.Value]); var conv = F.NN.Conv2D(inputDeq, weightsDeq, zeroBias, strideConst, pads, dilationConst, PadMode.Constant, group); return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); diff --git a/src/Nncase.Importer/TFLite/Conv2D.cs b/src/Nncase.Importer/TFLite/Conv2D.cs index 260ee76d0..7a1e414e4 100644 --- a/src/Nncase.Importer/TFLite/Conv2D.cs +++ b/src/Nncase.Importer/TFLite/Conv2D.cs @@ -33,8 +33,8 @@ private Expr VisitConv2D(in tflite.Operator op) var strideW = options.StrideW; var dilationH = options.DilationHFactor; var dilationW = options.DilationWFactor; - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.From(new[] { dilationH, dilationW }, [2]); var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false); var clamp = ToFloatClampRange(options.FusedActivationFunction); var inputQuantParams = GetInputQuantParams(op, 0); @@ -123,8 +123,8 @@ private Expr VisitDepthwiseConv2D(in tflite.Operator op) var strideW = options.StrideW; var dilationH = options.DilationHFactor; var dilationW = options.DilationWFactor; - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.From(new[] { dilationH, dilationW }, [2]); var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false); var depthMul = options.DepthMultiplier; if (depthMul != 1) diff --git a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs index 688cfe1f0..f552409df 100644 --- a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs +++ b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs @@ -36,8 +36,8 @@ private Expr VisitConv2DTranspose(in tflite.Operator op) var strideW = options.StrideW; var dilationH = 1; var dilationW = 1; - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.From(new[] { dilationH, dilationW }, [2]); var oldWShape = F.Tensors.ShapeOf(weights); var wShape = F.Tensors.Stack(new IR.Tuple(oldWShape[0], oldWShape[3], oldWShape[1], oldWShape[2]), 0); var padding = F.ShapeExpr.GetPaddings(F.Tensors.Stack(new IR.Tuple(newOutShape), 0), wShape, stride, dilation, options.Padding == tflite.Padding.SAME, false); diff --git a/src/Nncase.Importer/TFLite/ReduceWindow2D.cs b/src/Nncase.Importer/TFLite/ReduceWindow2D.cs index 1fabda70e..591ce9ae7 100644 --- a/src/Nncase.Importer/TFLite/ReduceWindow2D.cs +++ b/src/Nncase.Importer/TFLite/ReduceWindow2D.cs @@ -22,8 +22,8 @@ private Expr VisitReduceWindow2D(in tflite.Operator op, ReduceOp reduceOp, float var strideW = option.StrideW; var padH = Util.GetWindowedPadding(inH, filterH, strideH, 1, option.Padding == tflite.Padding.SAME); var padW = Util.GetWindowedPadding(inW, filterW, strideW, 1, option.Padding == tflite.Padding.SAME); - var filter = Tensor.From(new[] { filterH, filterW }, new[] { 2 }); - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); + var filter = Tensor.From(new[] { 
filterH, filterW }, [2]); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); var padding = Util.ConcatPadding(padH, padW); return F.Tensors.NCHWToNHWC( F.NN.ReduceWindow2D( diff --git a/src/Nncase.Importer/TFLite/TFLiteImporter.cs b/src/Nncase.Importer/TFLite/TFLiteImporter.cs index 8a6343933..bc190e555 100644 --- a/src/Nncase.Importer/TFLite/TFLiteImporter.cs +++ b/src/Nncase.Importer/TFLite/TFLiteImporter.cs @@ -156,7 +156,7 @@ private static Dimension[] GetShapeArray(tflite.Tensor tensor) } return Enumerable.Range(0, tensor.ShapeLength).Select(i => - tensor.ShapeSignature(i) < 0 ? Dimension.Unknown : tensor.Shape(i)).ToArray(); + tensor.ShapeSignature(i) < 0 ? Dimension.Unknown() : tensor.Shape(i)).ToArray(); } private void Visit(in tflite.Operator op) diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index b8809153b..ebed5d7d9 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -25,7 +25,7 @@ public override string ToString() public class ScheduleBuffer { - public ScheduleBuffer(string name, int number, Interval timeInterval, Interval memInterval, int[] shape, int[] strides, bool inplace) + public ScheduleBuffer(string name, int number, Interval timeInterval, Interval memInterval, long[] shape, long[] strides, bool inplace) { Name = name; Number = number; @@ -44,9 +44,9 @@ public ScheduleBuffer(string name, int number, Interval timeInterval, Interval m public Interval MemInterval { get; } - public int[] Shape { get; } + public long[] Shape { get; } - public int[] Strides { get; } + public long[] Strides { get; } public bool Inplace { get; } diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 1c2febde0..1c63619b2 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -101,13 +101,13 @@ public override Result VisitType(DistributedType distributedType) public override Result VisitType(TupleType tupleType) { - var size = 0; + long size = 0; foreach (var item in tupleType) { size += VisitType(item).Size; } - return new(size, Array.Empty(), Array.Empty()); + return new(size, Array.Empty(), Array.Empty()); } protected override Result VisitCall(Call expr) @@ -128,9 +128,9 @@ protected override Result VisitCall(Call expr) return VisitType(expr.CheckedType); } - public sealed record Result(int Size, int[] Shape, int[] Stride) + public sealed record Result(long Size, long[] Shape, long[] Stride) { - public static readonly Result Empty = new(0, Array.Empty(), Array.Empty()); + public static readonly Result Empty = new(0, Array.Empty(), Array.Empty()); } } diff --git a/src/Nncase.Passes/GraphPartition/GraphConvetor.cs b/src/Nncase.Passes/GraphPartition/GraphConvetor.cs index cfbe2ad8f..b4bd7c627 100644 --- a/src/Nncase.Passes/GraphPartition/GraphConvetor.cs +++ b/src/Nncase.Passes/GraphPartition/GraphConvetor.cs @@ -266,18 +266,25 @@ private bool HasCycles(SortedDictionary subgraphMap, Dictionary +public class GraphConvertor : ExprVisitor { - private int _nodeCount; - - public GraphConvertor(Func predicate) + public GraphConvertor(Func predicate) { Predicate = predicate; } - public Func Predicate { get; } + public int NodeGlobalIndex { get; protected set; } + + public Func Predicate { get; } - protected override Unit VisitGrid(Grid expr, GraphContext context) + protected virtual void 
UpdateContext(Vertex target, GraphContext context) + { + context.Graph.AddVertex(target); + context.SubgraphMap.Add(NodeGlobalIndex, new Subgraph(NodeGlobalIndex, new() { target }, new List(), new List(), new List())); + NodeGlobalIndex++; + } + + protected override Vertex VisitGrid(Grid expr, GraphContext context) { foreach (var operand in expr.Reads) { @@ -287,21 +294,10 @@ protected override Unit VisitGrid(Grid expr, GraphContext context) return VisitLeafGrid(expr, context); } - protected override Unit VisitLeafGrid(Grid expr, GraphContext context) + protected override Vertex VisitLeafGrid(Grid expr, GraphContext context) { - Vertex target; - if (Predicate(expr)) - { - target = new Vertex(expr, Compat.COMPATIBLE); - } - else - { - target = new Vertex(expr, Compat.INCOMPATIBLE); - } - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; + var target = new Vertex(expr, Predicate(expr)); + UpdateContext(target, context); foreach (var operand in expr.Reads) { if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, operand))) @@ -325,61 +321,58 @@ protected override Unit VisitLeafGrid(Grid expr, GraphContext context) } } - return default; + return target; } - protected override Unit VisitLeafVar(Var expr, GraphContext context) + protected override Vertex VisitLeafVar(Var expr, GraphContext context) { - Vertex target; - - target = new Vertex(expr, Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; + var target = new Vertex(expr, Predicate(expr)); + UpdateContext(target, context); + foreach (var operand in expr.Operands) + { + if (ExprMemo.TryGetValue(operand, out var source) && source is not null) + { + switch (source.CompatType, target.CompatType) + { + case (Compat.COMPATIBLE, Compat.INCOMPATIBLE): + context.Graph.AddEdge(new Edge(EdgeTypes.C2I, source, target)); + break; + case (Compat.INCOMPATIBLE, Compat.COMPATIBLE): + context.Graph.AddEdge(new Edge(EdgeTypes.I2C, source, target)); + break; + case (Compat.INCOMPATIBLE, Compat.INCOMPATIBLE): + context.Graph.AddEdge(new Edge(EdgeTypes.I2I, source, target)); + break; + default: + context.Graph.AddEdge(new Edge(EdgeTypes.C2C, source, target)); + break; + } + } + } - return default; + return target; } - protected override Unit VisitLeafConst(Const expr, GraphContext context) + protected override Vertex VisitLeafConst(Const expr, GraphContext context) { Vertex target; target = new Vertex(expr, expr.CheckedType is DistributedType ? 
Compat.COMPATIBLE : Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - - return default; + UpdateContext(target, context); + return target; } - protected override Unit VisitLeafNone(None expr, GraphContext context) + protected override Vertex VisitLeafNone(None expr, GraphContext context) { Vertex target; target = new Vertex(expr, Compat.INCOMPATIBLE); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; - - return default; + UpdateContext(target, context); + return target; } - protected override Unit VisitLeafCall(Call expr, GraphContext context) + protected override Vertex VisitLeafCall(Call expr, GraphContext context) { - Vertex target; - if (Predicate(expr)) - { - target = new Vertex(expr, Compat.COMPATIBLE); - } - else - { - target = new Vertex(expr, Compat.INCOMPATIBLE); - } - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; + var target = new Vertex(expr, Predicate(expr)); + UpdateContext(target, context); foreach (var operand in expr.Arguments) { if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, operand))) @@ -403,23 +396,20 @@ protected override Unit VisitLeafCall(Call expr, GraphContext context) } } - return default; + return target; } - protected override Unit VisitLeafTuple(IR.Tuple expr, GraphContext context) + protected override Vertex VisitLeafTuple(IR.Tuple expr, GraphContext context) { Vertex target; var compatType = context.Graph.Vertices.First(v => ReferenceEquals(v.Expr, expr.Fields[0])).CompatType; - if (!Predicate(expr)) + if (Predicate(expr) != Compat.COMPATIBLE) { compatType = Compat.INCOMPATIBLE; } target = new Vertex(expr, compatType); - - context.Graph.AddVertex(target); - context.SubgraphMap.Add(_nodeCount, new Subgraph(_nodeCount, new() { target }, new List(), new List(), new List())); - _nodeCount++; + UpdateContext(target, context); foreach (var field in expr.Fields) { if (context.Graph.Vertices.Any(v => ReferenceEquals(v.Expr, field))) @@ -443,11 +433,11 @@ protected override Unit VisitLeafTuple(IR.Tuple expr, GraphContext context) } } - return default; + return target; } - protected override Unit DefaultVisitLeaf(Expr expr, GraphContext context) + protected override Vertex DefaultVisitLeaf(Expr expr, GraphContext context) { - return default; + return null!; } } diff --git a/src/Nncase.Passes/PassesModule.cs b/src/Nncase.Passes/PassesModule.cs index b4e3caaff..eb9f96ec7 100644 --- a/src/Nncase.Passes/PassesModule.cs +++ b/src/Nncase.Passes/PassesModule.cs @@ -15,5 +15,6 @@ public void ConfigureServices(IRegistrator registrator) { registrator.Register(reuse: Reuse.Singleton); registrator.Register(reuse: Reuse.Singleton); + registrator.Register(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Passes/Rules/Arithmetic/IdentityLaw.cs b/src/Nncase.Passes/Rules/Arithmetic/IdentityLaw.cs index 8a2788f89..8e3efc1c2 100644 --- a/src/Nncase.Passes/Rules/Arithmetic/IdentityLaw.cs +++ b/src/Nncase.Passes/Rules/Arithmetic/IdentityLaw.cs @@ -60,7 +60,7 @@ public XAddNegX() public IPattern Pattern { get; } - private Expr? GetReplace(Expr x) => Tensor.FromBytes(x.CheckedDataType, new byte[x.CheckedDataType.SizeInBytes], Array.Empty()); + private Expr? 
GetReplace(Expr x) => Tensor.FromBytes(x.CheckedDataType, new byte[x.CheckedDataType.SizeInBytes], Array.Empty()); } /// diff --git a/src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs b/src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs index e0a6ba932..a6c6b46cd 100644 --- a/src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs +++ b/src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs @@ -129,7 +129,7 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o // Letterbox if (inputShape.Length == 4) { - int modelH, modelW; + long modelH, modelW; if (modelLayout != "NCHW") { @@ -179,13 +179,13 @@ protected override Task RunCoreAsync(IRModule module, RunPassContext o switch (mean.Length) { case 3 when inputShape.Length == 4: - meanCall = (Expr)Tensor.From(mean, new[] { 1, mean.Length, 1, 1 }); - stdCall = (Expr)Tensor.From(std, new[] { 1, std.Length, 1, 1 }); + meanCall = (Expr)Tensor.From(mean, [1, mean.Length, 1, 1]); + stdCall = (Expr)Tensor.From(std, [1, std.Length, 1, 1]); break; default: - meanCall = (Expr)Tensor.From(new float[] { mean[0] }, new[] { 1 }); - stdCall = (Expr)Tensor.From(new float[] { std[0] }, new[] { 1 }); + meanCall = (Expr)Tensor.From(new float[] { mean[0] }, [1]); + stdCall = (Expr)Tensor.From(new float[] { std[0] }, [1]); break; } diff --git a/src/Nncase.Passes/Rules/Neutral/AddToConv2D.cs b/src/Nncase.Passes/Rules/Neutral/AddToConv2D.cs index f57c9cd64..6147346b4 100644 --- a/src/Nncase.Passes/Rules/Neutral/AddToConv2D.cs +++ b/src/Nncase.Passes/Rules/Neutral/AddToConv2D.cs @@ -49,7 +49,7 @@ public sealed partial class AddToConv2D : IRewriteRule con_weights, bias: Tensor.FromScalar(0.0f, channels), stride: new[] { 1, 1 }, - padding: Tensor.From(new[] { 0, 0, 0, 0 }, new[] { 2, 2 }), + padding: Tensor.From(new[] { 0, 0, 0, 0 }, [2, 2]), dilation: new[] { 1, 1 }, padMode: PadMode.Constant, groups: 1); diff --git a/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs b/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs index 5e69dd90d..a1e709892 100644 --- a/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs +++ b/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs @@ -41,7 +41,7 @@ public sealed partial class BatchNormToBinary : IRewriteRule } var shape = input.CheckedShape.ToValueArray(); - var bnShape = Enumerable.Repeat(1, shape.Length - 1).ToArray(); + var bnShape = Enumerable.Repeat(1L, shape.Length - 1).ToArray(); bnShape[0] = shape[1]; var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } }); var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } }); diff --git a/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs b/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs index ce6881cf0..5f566094e 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineReshape.cs @@ -79,7 +79,7 @@ public CombineConstBinaryReshape() private Expr? 
GetReplace(Binary binary, Call call, IReadOnlyList callParams, Expr input, TensorConst constInput, TensorConst shape) { - var oldShape = shape.Value.ToArray(); + var oldShape = shape.Value.ToArray(); var significantShape = oldShape.Where(x => x > 1).ToArray(); bool leftConst = ReferenceEquals(callParams[0], constInput); @@ -98,7 +98,7 @@ public CombineConstBinaryReshape() if (significantShape.SequenceEqual(significantInputShape) && oldShape.Length > 0 && oldShape[^1] == constSize) { var broadcastIndex = Array.LastIndexOf(input.CheckedShape.ToValueArray(), constSize); - var newConstShape = Enumerable.Repeat(1, input.CheckedShape.Rank - 1 - broadcastIndex).ToList(); + var newConstShape = Enumerable.Repeat(1L, input.CheckedShape.Rank - 1 - broadcastIndex).ToList(); newConstShape.Insert(0, constSize); var res = Reshape(Binary(binary.BinaryOp, leftConst ? Reshape(constInput, newConstShape.ToArray()) : input, leftConst ? input : Reshape(constInput, newConstShape.ToArray())).InheritMetaData(call), call.CheckedShape); @@ -192,8 +192,8 @@ public sealed partial class CombineReshapePad : IRewriteRule && Enumerable.SequenceEqual(reshapeCall.CheckedShape.ToValueArray()[(reshapeRank - padRank)..], padCall.CheckedShape.ToValueArray())) { return Pad( - Reshape(input, Enumerable.Repeat(1, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray()).InheritMetaData(reshapeCall), - Tensor.From(Enumerable.Repeat(0, (reshapeRank - padRank) * 2).Concat(pads).ToArray(), new[] { reshapeRank, 2 }), + Reshape(input, Enumerable.Repeat(1L, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray()).InheritMetaData(reshapeCall), + Tensor.From(Enumerable.Repeat(0, (reshapeRank - padRank) * 2).Concat(pads).ToArray(), [reshapeRank, 2]), pad.PadMode, value).InheritMetaData(padCall); } @@ -229,7 +229,7 @@ public sealed partial class CombineReshapeTranspose : IRewriteRule { TypePattern = HasFixedShape() }, IsTensorConst("newShape")); - private Expr? GetReplace(Expr input, Call trans, int[] newShape, int[] perm) + private Expr? 
GetReplace(Expr input, Call trans, long[] newShape, int[] perm) { var transShape = trans.CheckedShape.ToValueArray(); diff --git a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs index 3cfa6b4a8..1a460a2ad 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs @@ -96,7 +96,7 @@ public CombineConstBinaryTranspose() return Transpose(Binary(binary.BinaryOp, x, y).InheritMetaData(binaryCall), perm); } - var newShape = new List() { x.CheckedShape[0].FixedValue }; + var newShape = new List() { x.CheckedShape[0].FixedValue }; if (x.CheckedShape[0].FixedValue != 1) { for (int i = 0; i < expandDim; i++) @@ -116,7 +116,7 @@ public CombineConstBinaryTranspose() return Transpose(Binary(binary.BinaryOp, x, y).InheritMetaData(binaryCall), perm); } - var newShape = new List() { y.CheckedShape[0].FixedValue }; + var newShape = new List() { y.CheckedShape[0].FixedValue }; if (y.CheckedShape[0].FixedValue != 1) { for (int i = 0; i < expandDim; i++) @@ -148,10 +148,10 @@ public sealed partial class CombineTransposeConstBinary : RewriteRule(new[] { ln.Epsilon }, new[] { 1 }))); + var rsigma = IR.F.Math.Rsqrt(IR.F.Math.Add(sigma, Tensor.From(new[] { ln.Epsilon }, [1]))); return IR.F.Math.Add(IR.F.Math.Mul(IR.F.Math.Mul(sub, rsigma), scale), bias); } else { var sigma = IR.F.Tensors.ReduceMean(IR.F.Math.Square(input), new[] { ln.Axis }, 0f, true); - var rsigma = IR.F.Math.Rsqrt(IR.F.Math.Add(sigma, Tensor.From(new[] { ln.Epsilon }, new[] { 1 }))); + var rsigma = IR.F.Math.Rsqrt(IR.F.Math.Add(sigma, Tensor.From(new[] { ln.Epsilon }, [1]))); return IR.F.Math.Add(IR.F.Math.Mul(IR.F.Math.Mul(input, rsigma), scale), bias); } } diff --git a/src/Nncase.Passes/Rules/Neutral/FoldConv2DAddMul.cs b/src/Nncase.Passes/Rules/Neutral/FoldConv2DAddMul.cs index 2057e6975..0942935d6 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldConv2DAddMul.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldConv2DAddMul.cs @@ -91,7 +91,7 @@ private static bool CheckConstTensor(Tensor t) private Expr? 
GetReplace(Call conv2dCall, IR.NN.Conv2D conv2d, Tensor weights, Tensor bias, Expr strides, Expr paddings, Expr dilation, Expr groups, Expr fusedClamp, Tensor addConst, Tensor mulConst, Expr input) { - int ic = weights.Shape[1].FixedValue; + long ic = weights.Shape[1].FixedValue; if (mulConst.Length != ic || addConst.Length != ic) { return null; diff --git a/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs b/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs index 0385b480c..e42615b99 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs @@ -101,9 +101,9 @@ private static CallPattern Conv2DPattern() => return res; } - private (int[] Begin, int[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, int[] btsInputShape) + private (long[] Begin, long[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, long[] btsInputShape) { - List shape_expend = new(); + List shape_expend = new(); var block_shape_produt = btsBlockShape.Aggregate((x, sum) => x * sum); for (var i = 0; i < btsBlockShape.Length; i++) { @@ -116,7 +116,7 @@ private static CallPattern Conv2DPattern() => shape_expend.Add(btsInputShape[i]); } - List shape_shrink = new(); + List shape_shrink = new(); shape_shrink.Add(shape_expend[btsBlockShape.Length]); for (var i = 0; i < btsBlockShape.Length; i++) { @@ -128,7 +128,7 @@ private static CallPattern Conv2DPattern() => shape_shrink.Add(btsInputShape[i]); } - List crop_begs = new(), crop_ends = new(); + List crop_begs = new(), crop_ends = new(); crop_begs.Add(0); crop_ends.Add(shape_shrink[0]); for (var i = 0; i < crop.GetLength(0); i++) @@ -145,15 +145,15 @@ private static CallPattern Conv2DPattern() => var cropBegin = crop_begs.ToArray(); var cropEnd = crop_ends.ToArray(); - var strides = Enumerable.Repeat(1, crop_begs.Count).ToArray(); + var strides = Enumerable.Repeat(1L, crop_begs.Count).ToArray(); var begin = NormalizeStridedSliceBegin(btsInputShape, cropBegin, strides, 0); var end = NormalizeStridedSliceEndEnd(btsInputShape, begin, cropEnd, strides, 0, 0); return (begin, end); } - private int[] NormalizeStridedSliceEndEnd(int[] in_shape, int[] begin, int[] end, int[] strides, int end_mask, int shrink_axis_mask) + private long[] NormalizeStridedSliceEndEnd(long[] in_shape, long[] begin, long[] end, long[] strides, int end_mask, int shrink_axis_mask) { - var new_shape = Enumerable.Range(0, strides.Length).ToArray(); + var new_shape = Enumerable.Range(0, strides.Length).ToArray().ToLongs(); for (var i = 0; i < new_shape.Length; i++) { var stride = strides[i]; @@ -167,9 +167,9 @@ private int[] NormalizeStridedSliceEndEnd(int[] in_shape, int[] begin, int[] end return new_shape; } - private int[] NormalizeStridedSliceBegin(int[] in_shape, int[] begin, int[] strides, int begin_mask) + private long[] NormalizeStridedSliceBegin(long[] in_shape, long[] begin, long[] strides, int begin_mask) { - var new_shape = Enumerable.Range(0, strides.Length).ToArray(); + var new_shape = Enumerable.Range(0, strides.Length).ToArray().ToLongs(); for (var i = 0; i < new_shape.Length; i++) { var stride = strides[i]; diff --git a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs index 4de3b7f3c..81f6b0c25 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs @@ -34,7 +34,7 @@ public sealed partial class FoldNopReshape : IRewriteRule private Expr? 
GetReplace(Expr input, TensorConst newShape) { - var newShapeArray = newShape.Value.ToArray(); + var newShapeArray = newShape.Value.ToArray(); if ((newShapeArray.Count(x => x == -1) == 1 && newShapeArray.Length == input.CheckedShape.Count && input.CheckedShape.Zip(newShapeArray).Count(t => t.Second != -1 && t.First.FixedValue == t.Second) == newShapeArray.Length - 1) || input.CheckedShape.ToValueArray().SequenceEqual(newShapeArray)) @@ -72,7 +72,7 @@ public sealed partial class FoldReshapeBinaryConstReshape : IRewriteRule public IPattern Pattern { get; } = IsReshape(IsSwappableBinary("binary", null, b => b.BinaryOp is BinaryOp.Add or BinaryOp.Mul, IsReshape(IsWildcard("input") with { TypePattern = HasFixedShape() }, IsTensorConst("unsqShape")), IsTensorConst("binaryConst")), IsTensorConst("sqShape")); - private Expr? GetReplace(Expr input, Binary binary, int[] unsqShape, TensorConst binaryConst, int[] sqShape) + private Expr? GetReplace(Expr input, Binary binary, long[] unsqShape, TensorConst binaryConst, long[] sqShape) { var inShape = input.CheckedShape.ToValueArray(); if (!(sqShape.SequenceEqual(inShape) && RulesUtility.FindSqueezeAxis(unsqShape, sqShape) is int axis && axis != -1 && ( diff --git a/src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs b/src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs index 5d72863c1..b36d97f96 100644 --- a/src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs +++ b/src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs @@ -56,8 +56,8 @@ public sealed partial class MatMulToConv2D : IRewriteRule if_reshape, w_reshape, Tensor.FromScalar(0.0f, w_shape[0].FixedValue), - Tensor.FromScalar(1, new[] { 2 }), - Tensor.FromScalar(0, new[] { 2, 2 }), + Tensor.FromScalar(1, [2]), + Tensor.FromScalar(0, [2, 2]), new int[] { 1, 1 }, PadMode.Constant, 1).InheritMetaData(matMulCall); @@ -101,8 +101,8 @@ public sealed partial class BroadcastMatMulToConv2D : IRewriteRule if_reshape, w_reshape, Tensor.FromScalar(0.0f, w_shape[0].FixedValue), - Tensor.FromScalar(1, new[] { 2 }), - Tensor.FromScalar(0, new[] { 2, 2 }), + Tensor.FromScalar(1, [2]), + Tensor.FromScalar(0, [2, 2]), new int[] { 1, 1 }, PadMode.Constant, 1).InheritMetaData(matMulCall); @@ -148,8 +148,8 @@ public sealed partial class BroadcastMatMul : IRewriteRule newOutputShape[^2] = aShape[^2].FixedValue; newOutputShape[^1] = bShape[^1].FixedValue; - var ifShape = new int[] { -1, aShape[^2].FixedValue, aShape[^1].FixedValue }; - var wShape = new int[] { -1, newBShape[^2], newBShape[^1] }; + var ifShape = new long[] { -1, aShape[^2].FixedValue, aShape[^1].FixedValue }; + var wShape = new long[] { -1, newBShape[^2], newBShape[^1] }; var bBroadCast = IR.F.Tensors.Broadcast(b, newBShape); List outputNames = new() { b.Metadata.OutputNames![0] + "_bBroadCast" }; bBroadCast.Metadata.OutputNames = outputNames; @@ -165,8 +165,8 @@ public sealed partial class BroadcastMatMul : IRewriteRule newOutputShape[^2] = aShape[^2].FixedValue; newOutputShape[^1] = bShape[^1].FixedValue; - var ifShape = new int[] { -1, newAShape[^2], newAShape[^1] }; - var wShape = new int[] { -1, bShape[^2].FixedValue, bShape[^1].FixedValue }; + var ifShape = new long[] { -1, newAShape[^2], newAShape[^1] }; + var wShape = new long[] { -1, bShape[^2].FixedValue, bShape[^1].FixedValue }; var aBroadCast = IR.F.Tensors.Broadcast(a, newAShape); List outputNames = new() { a.Metadata.OutputNames![0] + "_aBroadCast" }; aBroadCast.Metadata.OutputNames = outputNames; @@ -187,8 +187,8 @@ public sealed partial class BroadcastMatMul : IRewriteRule newOutputShape[i] = 
System.Math.Max(aShape[i].FixedValue, bShape[i].FixedValue); } - var ifShape = new int[] { -1, newAShape[^2], newAShape[^1] }; - var wShape = new int[] { -1, newBShape[^2], newBShape[^1] }; + var ifShape = new long[] { -1, newAShape[^2], newAShape[^1] }; + var wShape = new long[] { -1, newBShape[^2], newBShape[^1] }; var bBroadCast = IR.F.Tensors.Broadcast(b, newBShape); List bOutputNames = new() { b.Metadata.OutputNames?[0] + "_bBroadCast" }; bBroadCast.Metadata.OutputNames = bOutputNames; diff --git a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedFunctions.cs b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedFunctions.cs index 61790ddcb..1a7b00506 100644 --- a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedFunctions.cs +++ b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedFunctions.cs @@ -36,7 +36,7 @@ protected override Task RunCoreAsync(IRModule input, RunPassContext co foreach (var func in input.Functions) { if (!ReferenceEquals(func, input.Entry) - && func.Users.Count == 1) + && func.Users.Count() == 1) { funcsToRemove.Add(func); } diff --git a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs index 206bfc2c7..651466f81 100644 --- a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs +++ b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs @@ -34,7 +34,7 @@ public sealed partial class RemoveUnusedVarsByCall : IRewriteRule for (int i = 0; i < function.Parameters.Length; i++) { var var = function.Parameters[i]; - if (var.Users.Count == 1) + if (var.Users.Count() == 1) { unusedVars++; } @@ -87,8 +87,8 @@ public sealed partial class RemoveUnusedVarsByIf : IRewriteRule { var thenVar = thenFunc.Parameters[i]; var elseVar = elseFunc.Parameters[i]; - if (thenVar.Users.Count == 1 - && elseVar.Users.Count == 1) + if (thenVar.Users.Count() == 1 + && elseVar.Users.Count() == 1) { unusedVars++; } diff --git a/src/Nncase.Passes/Rules/Neutral/ReshapeExpand.cs b/src/Nncase.Passes/Rules/Neutral/ReshapeExpand.cs index d5b06be59..712022c40 100644 --- a/src/Nncase.Passes/Rules/Neutral/ReshapeExpand.cs +++ b/src/Nncase.Passes/Rules/Neutral/ReshapeExpand.cs @@ -52,7 +52,8 @@ public partial class ReshapeExpand : RewriteRule newShape.Add(dim); } - if (newShape.Count == shape.Length) // No 1 exists + // No 1 exists + if (newShape.Count == shape.Length) { return null; } diff --git a/src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs b/src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs index 216c5f1de..61ba39f3d 100644 --- a/src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs +++ b/src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs @@ -33,7 +33,7 @@ public partial class ReshapeMatMul : RewriteRule var shapeA = a.CheckedShape.ToValueArray(); if (a.CheckedShape.Rank == 4) { - var c = shapeA.Take(a.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum); + var c = shapeA.Take(a.CheckedShape.Rank - 2).Aggregate(1L, (sum, x) => x * sum); var newShapeA = new long[] { c, shapeA[^2], shapeA[^1] }; lhs = IR.F.Tensors.Reshape(a, newShapeA); } @@ -47,7 +47,7 @@ public partial class ReshapeMatMul : RewriteRule var shapeB = b.CheckedShape.ToValueArray(); if (b.CheckedShape.Rank == 4) { - var c = shapeB.Take(b.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum); + var c = shapeB.Take(b.CheckedShape.Rank - 2).Aggregate(1L, (sum, x) => x * sum); var newShapeB = new long[] { c, shapeB[^2], shapeB[^1] }; rhs = IR.F.Tensors.Reshape(b, newShapeB); } diff --git a/src/Nncase.Passes/Rules/Neutral/ScalarConstToTensor.cs b/src/Nncase.Passes/Rules/Neutral/ScalarConstToTensor.cs index 
848688cf7..15f877ed4 100644 --- a/src/Nncase.Passes/Rules/Neutral/ScalarConstToTensor.cs +++ b/src/Nncase.Passes/Rules/Neutral/ScalarConstToTensor.cs @@ -25,7 +25,7 @@ public partial class ScalarConstToTensor : RewriteRule { if (call.Arguments.AsValueEnumerable().Any(a => a is TensorConst { Value: Tensor { Shape.IsScalar: true } })) { - var arguments = call.Arguments.AsValueEnumerable().Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray(); + var arguments = call.Arguments.AsValueEnumerable().Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), [1])), _ => e }).ToArray(); return call.With(arguments: arguments); } diff --git a/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs b/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs index 821ffba66..90452c0a5 100644 --- a/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs +++ b/src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs @@ -48,7 +48,7 @@ public sealed partial class SpaceToBatchToPad : IRewriteRule newPaddingsArray[i + 4] = paddingsArray[i]; } - var newPaddings = Tensor.From(newPaddingsArray, new[] { 4, 2 }); + var newPaddings = Tensor.From(newPaddingsArray, [4, 2]); return Pad(input, newPaddings, PadMode.Constant, 0f); } diff --git a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs index 8494e7242..eb028f06b 100644 --- a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs +++ b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs @@ -33,7 +33,7 @@ public partial class SplitSpaceToBatch : RewriteRule public Expr? 
GetReplace(Expr input, Expr blockShape, Expr paddings) { - var spatialSize = blockShape.CheckedShape.Size; + var spatialSize = (int)blockShape.CheckedShape.Size; var remainShapeSize = input.CheckedShape.Rank - spatialSize - 1; var newPaddings = Enumerable.Repeat((Expr)0, (1 + spatialSize + remainShapeSize) * 2).ToArray(); for (int i = 0; i < spatialSize; i++) @@ -104,7 +104,7 @@ public partial class SplitBatchToSpace : RewriteRule { // to nhwc var input0 = NCHWToNHWC(input); - var blockLen = blockShape.CheckedShape.Size; + var blockLen = (int)blockShape.CheckedShape.Size; var xLen = input0.CheckedShape.Rank; var xShape = Cast(ShapeOf(input0), DataTypes.Int32); var spatial = ShapeExprUtility.Slice(xShape, 1, blockLen + 1); @@ -144,7 +144,7 @@ public partial class SplitBatchToSpace : RewriteRule return transposeResult; } - private static IEnumerable BoostRange(int start, int end, int step = 1) + private static IEnumerable BoostRange(int start, int end, int step = 1) { int x = start; do diff --git a/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs b/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs index f3a50c126..fc3638a79 100644 --- a/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs +++ b/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs @@ -47,7 +47,7 @@ public sealed partial class Squeeze5DTranspose : IRewriteRule var shape2 = perm1.Select(p => shape1[p]).ToArray(); int[] perm2; - int[] shape3; + long[] shape3; switch (perm1.IndexOf(3)) { case 0: @@ -181,15 +181,15 @@ public sealed partial class SqueezeTransposeShape : IRewriteRule IsWildcard("input") with { TypePattern = HasFixedShape() & HasRank(x => x > 4, "more than 4D need to squeeze") }, IsWildcard("perm")); - private Tuple, List> SqueezeTranspose(List oldShape, List oldAxis) + private Tuple, List> SqueezeTranspose(List oldShape, List oldAxis) { if (oldShape.Count <= 4) { - return new Tuple, List>(false, oldAxis, oldShape); + return new Tuple, List>(false, oldAxis, oldShape); } var newAxis = new List(oldAxis); - var newShape = new List(oldShape); + var newShape = new List(oldShape); int squeezeTimes = oldShape.Count - 4; var foldIndexCouple = new List>(); @@ -203,7 +203,7 @@ private Tuple, List> SqueezeTranspose(List oldShape, L if (foldIndexCouple.Count < squeezeTimes) { - return new Tuple, List>(false, newAxis, newShape); + return new Tuple, List>(false, newAxis, newShape); } while (squeezeTimes > 0 && foldIndexCouple.Count > 0) @@ -229,7 +229,7 @@ private Tuple, List> SqueezeTranspose(List oldShape, L } } - return new Tuple, List>(true, newAxis, newShape); + return new Tuple, List>(true, newAxis, newShape); } private Expr? GetReplace(Expr input, int[] perm, Expr call) @@ -241,7 +241,7 @@ private Tuple, List> SqueezeTranspose(List oldShape, L return null; } - var newOutputShape = new int[perm.Length]; + var newOutputShape = new long[perm.Length]; for (int i = 0; i < perm.Length; i++) { newOutputShape[i] = inputShape[perm[i]].FixedValue; @@ -263,7 +263,7 @@ public sealed partial class SqueezeBinaryShape : IRewriteRule /// left input shape. /// right input shape. /// Squeeze flag, new lhs, new rhs. 
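The squeeze rules around here (SqueezeTransposeShape above, SqueezeBinaryShape below) keep their squeeze-to-4-D behaviour but now carry dimensions as long. As a self-contained illustration of the underlying idea, folding adjacent dimensions until at most four remain; the names below are hypothetical and not the rules' actual helpers:

    using System.Collections.Generic;

    internal static class SqueezeSketch
    {
        // Folds leading dimensions together until at most four remain,
        // e.g. [2, 3, 4, 5, 6] -> [6, 4, 5, 6] (the 2 and 3 are merged).
        public static List<long> FoldToRank4(IReadOnlyList<long> shape)
        {
            var result = new List<long>(shape);
            while (result.Count > 4)
            {
                result[1] *= result[0];
                result.RemoveAt(0);
            }

            return result;
        }
    }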
- public (bool SqueezeOrNot, List NewAShape, List NewBShape) SqueezeInputShape(List a, List b) + public (bool SqueezeOrNot, List NewAShape, List NewBShape) SqueezeInputShape(List a, List b) { var aSize = a.Count; var bSize = b.Count; @@ -277,8 +277,8 @@ public sealed partial class SqueezeBinaryShape : IRewriteRule return (false, a, b); } - List newA = a; - List newB = b; + List newA = a; + List newB = b; if (aSize == bSize) { @@ -353,9 +353,9 @@ public sealed partial class SqueezeBinaryShape : IRewriteRule return (true, newA, newB); } - private static List SqueezeShape(List shape) + private static List SqueezeShape(List shape) { - var newShape = new List { 1, 1, 1, 1 }; + var newShape = new List { 1, 1, 1, 1 }; for (int i = shape.Count - 1, k = 3; i >= 0; i--) { @@ -369,7 +369,7 @@ private static List SqueezeShape(List shape) return newShape; } - private static List GetOutputShape(List a, List b) + private static List GetOutputShape(List a, List b) { if (a.Count == 1) { diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs index 2b02fae6c..bddf4e785 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs @@ -24,13 +24,6 @@ namespace Nncase.Passes.Rules.ShapeBucket; #if true -internal sealed class EffectVarEqualityComparer : IEqualityComparer -{ - public bool Equals(Var[]? x, Var[]? y) => System.Collections.StructuralComparisons.StructuralEqualityComparer.Equals(x, y); - - public int GetHashCode([DisallowNull] Var[] obj) => System.Collections.StructuralComparisons.StructuralEqualityComparer.GetHashCode(obj); -} - public sealed class MergeBucketFusionPass : FunctionPass { protected override Task RunCoreAsync(BaseFunction baseFunction, RunPassContext context) @@ -146,11 +139,18 @@ private Function Perform(Function pre, HashSet effectVarSet) } } - var post = pre.With(pre.Name, exprMemo[pre.Body], pre.Parameters.ToArray()); + var post = pre.With(name: pre.Name, body: exprMemo[pre.Body], parameters: pre.Parameters.ToArray()); return post; } } +internal sealed class EffectVarEqualityComparer : IEqualityComparer +{ + public bool Equals(Var[]? x, Var[]? 
y) => System.Collections.StructuralComparisons.StructuralEqualityComparer.Equals(x, y); + + public int GetHashCode([DisallowNull] Var[] obj) => System.Collections.StructuralComparisons.StructuralEqualityComparer.GetHashCode(obj); +} + internal sealed class SubFusionCloner : ExprCloner { private readonly Dictionary _feedDict; diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs index d3d1ad0d5..436d7113f 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs @@ -456,7 +456,9 @@ private static Expr[] MakeNewPrevCalls(Call[] inputsShouldBeMerge, Expr[] prevOu } Expr canditateVar = newVar.First(); - if (x is Marker mm && mm.Attribute is TensorConst) // Const range of marker + + // Const range of marker + if (x is Marker mm && mm.Attribute is TensorConst) { canditateVar = mm.With(target: canditateVar); } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index 187d6fe45..98805ca78 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -1035,6 +1035,7 @@ public static Function MakeSplitEntry(FusionBucketContext context, Dictionary(true); // p.Add(true); // }); - MergeFusion(p, singleVar, false); } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs b/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs index 2d2bce5b2..108b99b4f 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/SplitLLMStage.cs @@ -128,8 +128,10 @@ private static void RegisterBucketPass(IPassManager p, bool singleVar) LostToFusion(p, singleVar); MergeOp(p, true); + // ClearMarker(p); MergeFusion(p, singleVar, true); + // Rebuild(p, singleVar); Bucket(p); Simplify(p); diff --git a/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs index dd530a7c5..0f3ab0e0d 100644 --- a/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs +++ b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs @@ -56,7 +56,7 @@ public partial class FoldSplitShapeOf : RewriteRule return null; } - if (getItemIndices.SequenceEqual(Enumerable.Range(0, shapeOf.CheckedShape[0].FixedValue))) + if (getItemIndices.SequenceEqual(Enumerable.Range(0, (int)shapeOf.CheckedShape[0].FixedValue))) { return shapeOf; } diff --git a/src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs b/src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs index e6a36485d..a68347f77 100644 --- a/src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs +++ b/src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs @@ -53,11 +53,11 @@ public sealed partial class CombineReshapePad : IRewriteRule var newPad = Pad( marker.With(target: Reshape( marker.With(target: input), - Enumerable.Repeat(1, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray()) + Enumerable.Repeat(1L, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray()) .InheritMetaData(reshapeCall)), Tensor.From( Enumerable.Repeat(0, (reshapeRank - padRank) * 2).Concat(pads).ToArray(), - new[] { reshapeRank, 2 }), + [reshapeRank, 2]), pad.PadMode, value).InheritMetaData(padCall); var outMarker = result.GetValueOrDefault("outMarker"); diff --git a/src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs b/src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs index 
6fcf83b91..339365416 100644 --- a/src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs +++ b/src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs @@ -30,8 +30,6 @@ namespace Nncase.Passes.Rules.Neutral; [RuleGenerator] public sealed partial class FoldConv2DBiasWithMarker : IRewriteRule { - private static int _counter; - /// public IPattern Pattern { get; } = IsRangeOfMarker( "binarym", diff --git a/src/Nncase.Passes/Rules/WithMarker/MatMulToConv2DWithMarker.cs b/src/Nncase.Passes/Rules/WithMarker/MatMulToConv2DWithMarker.cs index 3d8fa2ce9..ba8659743 100644 --- a/src/Nncase.Passes/Rules/WithMarker/MatMulToConv2DWithMarker.cs +++ b/src/Nncase.Passes/Rules/WithMarker/MatMulToConv2DWithMarker.cs @@ -44,7 +44,7 @@ public sealed partial class MatMulToConv2DWithMarker : IRewriteRule { var aShape = a.CheckedShape; var bShape = b.CheckedShape; - if (aShape.Count > 2 && aShape.ToValueArray()[..^2].Aggregate(1, (sum, x) => sum * x) != 1) + if (aShape.Count > 2 && aShape.ToValueArray()[..^2].Aggregate(1L, (sum, x) => sum * x) != 1) { return null; } @@ -65,8 +65,8 @@ public sealed partial class MatMulToConv2DWithMarker : IRewriteRule am.With(target: if_reshape), bm.With(target: w_reshape), Tensor.FromScalar(0.0f, w_shape[0].FixedValue), - Tensor.FromScalar(1, new[] { 2 }), - Tensor.FromScalar(0, new[] { 2, 2 }), + Tensor.FromScalar(1, [2]), + Tensor.FromScalar(0, [2, 2]), new int[] { 1, 1 }, PadMode.Constant, 1).InheritMetaData(matMulCall); @@ -115,8 +115,8 @@ public sealed partial class BroadcastMatMulToConv2DWithMarker : IRewriteRule am.With(target: if_reshape), bm.With(target: w_reshape), Tensor.FromScalar(0.0f, w_shape[0].FixedValue), - Tensor.FromScalar(1, new[] { 2 }), - Tensor.FromScalar(0, new[] { 2, 2 }), + Tensor.FromScalar(1, [2]), + Tensor.FromScalar(0, [2, 2]), new int[] { 1, 1 }, PadMode.Constant, 1).InheritMetaData(matMulCall); diff --git a/src/Nncase.Passes/RulesUtility.cs b/src/Nncase.Passes/RulesUtility.cs index 29afea146..2b15846c5 100644 --- a/src/Nncase.Passes/RulesUtility.cs +++ b/src/Nncase.Passes/RulesUtility.cs @@ -13,7 +13,7 @@ public static class RulesUtility /// old shape. /// new shape. /// axis, if not found return -1. - public static int FindSqueezeAxis(int[] oldShape, int[] newShape) + public static int FindSqueezeAxis(long[] oldShape, long[] newShape) { if (oldShape.Length <= newShape.Length) { diff --git a/src/Nncase.Passes/SimplifyProvider.cs b/src/Nncase.Passes/SimplifyProvider.cs new file mode 100644 index 000000000..6107d9873 --- /dev/null +++ b/src/Nncase.Passes/SimplifyProvider.cs @@ -0,0 +1,73 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
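The new SimplifyProvider.cs below is registered against ISimplifyProvider in PassesModule above. Based on the implementation that follows, the interface presumably exposes a single SimplifyForDimension member; this is a sketch of the assumed declaration, not the actual one defined elsewhere in the repository:

    using Nncase.IR;

    namespace Nncase.Passes;

    public interface ISimplifyProvider
    {
        // Simplifies an expression used as a shape/dimension value; the implementation
        // below currently short-circuits and returns the input expression unchanged.
        Expr SimplifyForDimension(Expr expr);
    }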
+ +using System; +using System.Collections; +using System.Collections.Generic; +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Nncase.CodeGen; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Rules.Neutral; +using Nncase.Quantization; + +namespace Nncase.Passes; + +internal sealed class SimplifyTarget : ITarget +{ + public string Kind => "Simplify"; + + public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List rangeOfs, List childrenOfRangeOfs, QuantizeOptions quantizeOptions) => throw new NotImplementedException(); + + public Task, List>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List rangeOfs, List childrenOfRangeOfs, QuantizeOptions quantizeOptions) => throw new NotImplementedException(); + + public IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions options) => throw new NotImplementedException(); + + public void ParseTargetDependentOptions(IConfigurationSection configure) => throw new NotImplementedException(); + + public (Command Command, Func Parser) RegisterCommandAndParser() => throw new NotImplementedException(); + + public void RegisterQuantizePass(IPassManager passManager, CompileOptions options) => throw new NotImplementedException(); + + public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options) => throw new NotImplementedException(); + + public void RegisterTargetDependentBeforeCodeGen(IPassManager passManager, CompileOptions options) => throw new NotImplementedException(); + + public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options) => throw new NotImplementedException(); + + public void RegisterTargetInDependentPass(IPassManager passManager, CompileOptions options) => throw new NotImplementedException(); +} + +internal sealed class SimplifyProvider : ISimplifyProvider +{ + private readonly CompileSession _compileSession; + private readonly IRewriteRule[] _rules; + + public SimplifyProvider() + { + _compileSession = CompileSession.Create(new SimplifyTarget(), new CompileOptions()); + using var compileScope = new CompileSessionScope(_compileSession); + _rules = [ + new FoldConstCall(), + ]; + } + + public Expr SimplifyForDimension(Expr expr) + { +#if false + if (expr is not (Const or Var)) + { + using var compileScope = new CompileSessionScope(CompileSessionScope.Current ?? 
_compileSession); + using var dumpScope = new DumpScope(NullDumpper.Instance); + expr = CompilerServices.Rewrite(expr, _rules, new RunPassContext()); + } + + return expr; +#else + return expr; +#endif + } +} diff --git a/src/Nncase.Quantization/Quantization/CalibrationEvaluator.cs b/src/Nncase.Quantization/Quantization/CalibrationEvaluator.cs index 0a6dfef18..a917d1a37 100644 --- a/src/Nncase.Quantization/Quantization/CalibrationEvaluator.cs +++ b/src/Nncase.Quantization/Quantization/CalibrationEvaluator.cs @@ -260,7 +260,7 @@ private IValue VisitLeaf(ENode enode, Func valueGetter) var valueArray = value.AsTensor().ToArray(); int index = 0; - int size = 0; + long size = 0; if (((Marker)enode.Expr).MixQuantInfo!.QuantParameter.Count != 1) { size = value.AsTensor().Shape[1].FixedValue * value.AsTensor().Shape[2].FixedValue * value.AsTensor().Shape[3].FixedValue; @@ -270,7 +270,7 @@ private IValue VisitLeaf(ENode enode, Func valueGetter) { if (((Marker)enode.Expr).MixQuantInfo!.QuantParameter.Count != 1) { - index = i / size; + index = checked(i / (int)size); } var valueArrayQuant = Math.Round((valueArray[i] / (double)((Marker)enode.Expr).MixQuantInfo!.QuantParameter[index].Scale) + ((Marker)enode.Expr).MixQuantInfo!.QuantParameter[index].ZeroPoint); diff --git a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs index 6aef026d8..d85d80661 100644 --- a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs +++ b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs @@ -58,7 +58,7 @@ public PytestCalibrationDatasetProvider(IReadOnlyList vars, string dataset) { case TensorType tensorType: { - int[] shape = Array.Empty(); + long[] shape = Array.Empty(); if (tensorType.Shape.IsFixed) { shape = tensorType.Shape.ToValueArray(); @@ -113,17 +113,17 @@ private sealed record Sample(string Name, int Number, int InputIndex) { public string FileName => $"{Name}_{InputIndex}_{Number}.bin"; - public int[] GetShape() + public long[] GetShape() { using var stream = File.OpenRead($"{Name}_{InputIndex}_{Number}.txt"); using var reader = new StreamReader(stream); var line = reader.ReadLine(); - int[] shape = Array.Empty(); + long[] shape = Array.Empty(); if (line is string shapeString) { string pattern = @"\d+"; MatchCollection matches = Regex.Matches(shapeString, pattern); - shape = matches.Select(m => int.Parse(m.Value)).ToArray(); + shape = matches.Select(m => long.Parse(m.Value)).ToArray(); } return shape; diff --git a/src/Nncase.Quantization/Quantization/QuantUtility.cs b/src/Nncase.Quantization/Quantization/QuantUtility.cs index 3a5dcef39..9b9230b92 100644 --- a/src/Nncase.Quantization/Quantization/QuantUtility.cs +++ b/src/Nncase.Quantization/Quantization/QuantUtility.cs @@ -21,7 +21,7 @@ namespace Nncase.Quantization; /// public static class QuantAlgorithmUtility { - public static Tensor SquantWeights(Tensor inputWeights, Tensor inputWeightsRanges, ReadOnlySpan inputWeightsShape, QuantMode quantMode, int bits, bool isByChannel) + public static Tensor SquantWeights(Tensor inputWeights, Tensor inputWeightsRanges, ReadOnlySpan inputWeightsShape, QuantMode quantMode, int bits, bool isByChannel) { float qMax, qMin; if (quantMode == QuantMode.UnsignedMode) @@ -51,7 +51,7 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor { @@ -83,7 +83,7 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor { diff --git 
a/src/Nncase.Quantization/Quantization/Quantizer.Algorithms.cs b/src/Nncase.Quantization/Quantization/Quantizer.Algorithms.cs index c59723369..e760d80dd 100644 --- a/src/Nncase.Quantization/Quantization/Quantizer.Algorithms.cs +++ b/src/Nncase.Quantization/Quantization/Quantizer.Algorithms.cs @@ -7,7 +7,6 @@ using System.Text; using System.Threading.Tasks; using Nncase.IR; -using Nncase.IR.Tensors; using Nncase.IR.F; using Nncase.TIR; using Math = System.Math; diff --git a/src/Nncase.Quantization/Quantization/Quantizer.cs b/src/Nncase.Quantization/Quantization/Quantizer.cs index cca22d01b..a9ca565a4 100644 --- a/src/Nncase.Quantization/Quantization/Quantizer.cs +++ b/src/Nncase.Quantization/Quantization/Quantizer.cs @@ -761,7 +761,7 @@ private void AssignByChannelRanges(IDictionary[]> range minMaxArr[i] = i % 2 == 0 ? value[i / 2].Min : value[i / 2].Max; } - var shape = oc == 1 ? new[] { 2 } : new[] { oc, 2 }; + long[] shape = oc == 1 ? [2] : [oc, 2]; var rangeEclass = _graph.Add(new TensorConst(Tensor.From(minMaxArr, shape))); var rangeOfEclass = _graph.Find(range.Key); range.Key.Expr.CheckedType = rangeEclass.CheckedType; diff --git a/src/Nncase.Schedule/Schedule/AffineTiler.cs b/src/Nncase.Schedule/Schedule/AffineTiler.cs index 794dbc369..82aed3128 100644 --- a/src/Nncase.Schedule/Schedule/AffineTiler.cs +++ b/src/Nncase.Schedule/Schedule/AffineTiler.cs @@ -19,14 +19,14 @@ namespace Nncase.Schedule; -internal sealed record AffineTilerMemo(IRArray BufferShapes, IRArray DomainBounds, IRArray AffineMaps, Type OpType, int ElemSize) +internal sealed record AffineTilerMemo(IRArray BufferShapes, IRArray DomainBounds, IRArray AffineMaps, Type OpType, int ElemSize) { } internal sealed class AffineTiler { private readonly Grid _grid; - private readonly int[] _domainBounds; + private readonly long[] _domainBounds; private readonly ILogger _logger; public AffineTiler(Grid grid, ITargetOptions targetOptions, ILoggerFactory loggerFactory) diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TileGraphTypes.cs b/src/Nncase.Schedule/Schedule/TileGraph/TileGraphTypes.cs index 00598568b..0305fa495 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TileGraphTypes.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TileGraphTypes.cs @@ -45,7 +45,7 @@ public DomainRelation ApplyRange(DomainRelation other) public sealed class TileGrid : ITileable { - public TileGrid(Grid grid, Op op, int opId, IEnumerable dimNames, IEnumerable domainBounds, IEnumerable> bufferShapes) + public TileGrid(Grid grid, Op op, int opId, IEnumerable dimNames, IEnumerable domainBounds, IEnumerable> bufferShapes) { Level = 0; Grid = grid; @@ -66,9 +66,9 @@ public TileGrid(Grid grid, Op op, int opId, IEnumerable dimNames, IEnume public Op Op { get; } - public ImmutableArray DomainBounds { get; } + public ImmutableArray DomainBounds { get; } - public ImmutableArray> BufferShapes { get; } + public ImmutableArray> BufferShapes { get; } public ReadOnlySpan ReadAccesses => Grid.AccessMaps[..^1]; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs b/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs index 53ba7d6b5..622545e9c 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TileTreeTypes.cs @@ -54,9 +54,9 @@ public OpNode(ITreeNode? 
parent, TileGrid wrapped) public Op Op => _wrapped.Op; - public ImmutableArray DomainBounds => _wrapped.DomainBounds; + public ImmutableArray DomainBounds => _wrapped.DomainBounds; - public ImmutableArray> BufferShapes => _wrapped.BufferShapes; + public ImmutableArray> BufferShapes => _wrapped.BufferShapes; public ReadOnlySpan ReadAccesses => _wrapped.ReadAccesses; diff --git a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs index 44075107d..7653549a8 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/TreeSolverPythonPrinter.cs @@ -32,7 +32,7 @@ public Unit Visit(TileNode value, (ITreeNode? Parent, IndentedTextWriter Writer) var trip = Solution.Value(TileNodeMemo[value].TripCounts[i + 1].Var()); // 2. write loop. - int parentBounds = 0; + long parentBounds = 0; if (parent is null) { value.Walk(child => diff --git a/src/Nncase.Schedule/Schedule/TileTree/TileTreeTypes.cs b/src/Nncase.Schedule/Schedule/TileTree/TileTreeTypes.cs index 7eb7a44fe..35c2bdf81 100644 --- a/src/Nncase.Schedule/Schedule/TileTree/TileTreeTypes.cs +++ b/src/Nncase.Schedule/Schedule/TileTree/TileTreeTypes.cs @@ -153,7 +153,7 @@ public override string ToString() public sealed class OpNode : ITileAbleNode { - public OpNode(Grid grid, Op op, int opId, IEnumerable dimNames, IEnumerable domainBounds, IEnumerable> bufferShapes, IEnumerable dependences) + public OpNode(Grid grid, Op op, int opId, IEnumerable dimNames, IEnumerable domainBounds, IEnumerable> bufferShapes, IEnumerable dependences) { Level = 0; Grid = grid; @@ -188,9 +188,9 @@ public OpNode(Grid grid, Op op, int opId, IEnumerable dimNames, IEnumera public ImmutableArray Dependences { get; } - public ImmutableArray DomainBounds { get; } + public ImmutableArray DomainBounds { get; } - public ImmutableArray> BufferShapes { get; } + public ImmutableArray> BufferShapes { get; } public ReadOnlySpan ReadAccesses => Grid.AccessMaps[..^1]; diff --git a/src/Nncase.Schedule/Schedule/TileTree/TreeCloner.cs b/src/Nncase.Schedule/Schedule/TileTree/TreeCloner.cs index ef1b55c77..9d926c8b0 100644 --- a/src/Nncase.Schedule/Schedule/TileTree/TreeCloner.cs +++ b/src/Nncase.Schedule/Schedule/TileTree/TreeCloner.cs @@ -44,7 +44,7 @@ public ITreeNode Visit(OpNode value, Unit arg1) { if (!_memo.TryGetValue(value, out var nOp)) { - nOp = new OpNode(value.Grid, value.Op, value.OpId, value.DimNames, value.DomainBounds, value.BufferShapes.Select(x => (IEnumerable)x), value.Dependences.Select(d => new OpNode.Dependence(d.Index, (OpNode)_memo[d.Node]))) + nOp = new OpNode(value.Grid, value.Op, value.OpId, value.DimNames, value.DomainBounds, value.BufferShapes.Select(x => (IEnumerable)x), value.Dependences.Select(d => new OpNode.Dependence(d.Index, (OpNode)_memo[d.Node]))) { DomainRelation = value.DomainRelation, }; diff --git a/src/Nncase.Schedule/Schedule/TileTree/TreeSolverInitializer.cs b/src/Nncase.Schedule/Schedule/TileTree/TreeSolverInitializer.cs index 922c8065f..bdae7ee3c 100644 --- a/src/Nncase.Schedule/Schedule/TileTree/TreeSolverInitializer.cs +++ b/src/Nncase.Schedule/Schedule/TileTree/TreeSolverInitializer.cs @@ -231,7 +231,7 @@ public InitResult Visit(OpNode value, Context context) var tileVars = value.DimNames.Select(n => Solver.MakeIntVar(1, long.MaxValue, $"{n}_L{value.Level}")).ToArray(); // CompilerServices.GetOpMicroKernelInfo(value.Op, value.AccessMaps[0].Domains.AsValueEnumerable().Select(i => i.Offset).ToArray(), 
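The tiling hunks above widen domain bounds and buffer shapes from int to long end to end (GetBufferShape, InferDomainBounds, the solver's solution values). The motivation is to avoid silent int overflow once dimension products get large; a small self-contained illustration of that failure mode, not repository code:

    using System;
    using System.Linq;

    internal static class OverflowSketch
    {
        public static void Main()
        {
            long[] shape = [1, 50000, 50000, 1];                  // 2.5e9 elements
            long wide = shape.Aggregate(1L, (acc, d) => acc * d); // 2500000000
            int narrow = unchecked((int)wide);                    // wraps to -1794967296
            Console.WriteLine($"long: {wide}, int: {narrow}");
        }
    }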
value.AccessMaps.ToArray(), value.BufferShapes, TargetOptions); - var kernelInfo = new MicroKernelInfo(tileVars.Select(i => 1).ToArray(), tileVars.Select((_, i) => new ValueRange(0, value.DomainBounds[i])).ToArray(), Array.Empty(), (_, b, _) => b.MakeIntConst(1)); + var kernelInfo = new MicroKernelInfo(tileVars.Select(i => 1).ToArray(), tileVars.Select((_, i) => new ValueRange(0, value.DomainBounds[i])).ToArray(), Array.Empty(), (_, b, _) => b.MakeIntConst(1)); for (int i = 0; i < tileVars.Length; i++) { diff --git a/src/Nncase.Schedule/Schedule/TilingSolver.cs b/src/Nncase.Schedule/Schedule/TilingSolver.cs index b166ef8df..36afacc41 100644 --- a/src/Nncase.Schedule/Schedule/TilingSolver.cs +++ b/src/Nncase.Schedule/Schedule/TilingSolver.cs @@ -26,7 +26,7 @@ public TilingSolver(ITargetOptions targetOptions) public ITargetOptions TargetOptions { get; } - public GridSchedule Solve(int[] domainBounds, int[][] bufferShapes, AffineDim[] domain, AffineMap[] accessMaps, Op computation, int elemSize) + public GridSchedule Solve(long[] domainBounds, long[][] bufferShapes, AffineDim[] domain, AffineMap[] accessMaps, Op computation, int elemSize) { int[] memoryCapacitys = new[] { 512 * 1024, int.MaxValue }; int[] memoryBandWidths = new[] { 128, 4 }; @@ -90,7 +90,7 @@ private static LoopMasks GetLoopMasks(AffineMap map) return new(masks); } - private GridSchedule? SolveWithPermutation(int[] domainBounds, int[][] bufferShapes, AffineDim[,] fullDomain, AffineMap[] accessMaps, LoopMasks[] loopMasks, int[] memoryCapacitys, int[] memoryBandWidths, string prefix, ref long bestObjective, Op computation, int elemSize) + private GridSchedule? SolveWithPermutation(long[] domainBounds, long[][] bufferShapes, AffineDim[,] fullDomain, AffineMap[] accessMaps, LoopMasks[] loopMasks, int[] memoryCapacitys, int[] memoryBandWidths, string prefix, ref long bestObjective, Op computation, int elemSize) { var totalLevel = memoryCapacitys.Length; var model = new Solver("tiling"); diff --git a/src/Nncase.Schedule/Schedule/TilingUtilities.cs b/src/Nncase.Schedule/Schedule/TilingUtilities.cs index dc0361ff7..65d3bd680 100644 --- a/src/Nncase.Schedule/Schedule/TilingUtilities.cs +++ b/src/Nncase.Schedule/Schedule/TilingUtilities.cs @@ -20,7 +20,7 @@ public static Expr GetUninitialized(Expr expr) }; } - public static int[] GetBufferShape(Expr buffer) + public static long[] GetBufferShape(Expr buffer) { return buffer.CheckedType switch { @@ -30,7 +30,7 @@ public static int[] GetBufferShape(Expr buffer) }; } - public static int[] InferDomainBounds(int[][] bufferShapes, AffineMap[] accessMaps) + public static long[] InferDomainBounds(long[][] bufferShapes, AffineMap[] accessMaps) { var solver = new Solver("affineSolver"); var converter = new AffineExprToIntExprConverter(solver); @@ -57,7 +57,7 @@ public static int[] InferDomainBounds(int[][] bufferShapes, AffineMap[] accessMa throw new InvalidOperationException(); } - var dims = dimVars.Select(x => (int)solutionCollector.Value(0, x)).ToArray(); + var dims = dimVars.Select(x => solutionCollector.Value(0, x)).ToArray(); return dims; } } diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index 44b976b4f..ced0d35a3 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -53,9 +53,9 @@ private BaseFunction Rewrite(BaseFunction pre, int funcNumber, GraphTiler tiler) var ctx = new GraphContext(); var convertor = new GraphConvertor(x => x switch { - Grid => true, - IR.Tuple 
tp => tp.Fields.AsValueEnumerable().All(f => f is Grid), - _ => false, + Grid => Compat.COMPATIBLE, + IR.Tuple tp => tp.Fields.AsValueEnumerable().All(f => f is Grid) ? Compat.COMPATIBLE : Compat.INCOMPATIBLE, + _ => Compat.INCOMPATIBLE, }); convertor.Visit(fusion.Body, ctx); diff --git a/src/Nncase.Simulator/Runtime/Interop/RTTensor.cs b/src/Nncase.Simulator/Runtime/Interop/RTTensor.cs index df582f6be..350b423cd 100644 --- a/src/Nncase.Simulator/Runtime/Interop/RTTensor.cs +++ b/src/Nncase.Simulator/Runtime/Interop/RTTensor.cs @@ -196,8 +196,8 @@ public static unsafe RTTensor FromTensor(Tensor tensor) tensor.BytesBuffer.CopyTo(mem.Memory.Span); } - var dims = MemoryMarshal.Cast(tensor.Dimensions); - var strides = MemoryMarshal.Cast(tensor.Strides); + var dims = MemoryMarshal.Cast(tensor.Dimensions.ToInts()); + var strides = MemoryMarshal.Cast(tensor.Strides.ToInts()); return Create(RTDataType.FromTypeCode(dtype.TypeCode), dims, strides, new RTBufferSlice { Buffer = buffer, Start = 0, SizeBytes = sizeBytes }); } diff --git a/src/Nncase.Studio/ViewModels/SimulateViewModel.cs b/src/Nncase.Studio/ViewModels/SimulateViewModel.cs index 2f6b0521c..48d01f5a9 100644 --- a/src/Nncase.Studio/ViewModels/SimulateViewModel.cs +++ b/src/Nncase.Studio/ViewModels/SimulateViewModel.cs @@ -229,7 +229,7 @@ private void SaveResult(Tensor[] result) var list = result .Select(t => np.frombuffer(t.BytesBuffer.ToArray(), t.ElementType.CLRType) - .reshape(t.Shape.ToValueArray())) + .reshape(t.Shape.ToValueArray().ToInts())) .ToArray(); for (int i = 0; i < list.Length; i++) diff --git a/src/Nncase.Tests.TestFixture/TestingServices.cs b/src/Nncase.Tests.TestFixture/TestingServices.cs index c4acf12d6..60001626d 100644 --- a/src/Nncase.Tests.TestFixture/TestingServices.cs +++ b/src/Nncase.Tests.TestFixture/TestingServices.cs @@ -75,7 +75,7 @@ public static ValueRange FixupRange(ValueRange range, bool symmetr /// /// create the rand value by gived datatype. /// - public static Tensor Rand(DataType dataType, params int[] shape) + public static Tensor Rand(DataType dataType, params long[] shape) { return IR.F.Random.Normal(dataType, 0, 1, 1, shape).Evaluate().AsTensor(); } @@ -100,7 +100,7 @@ public static Tensor Seq(DataType dataType, params int[] shape) /// /// create the seq value by gived datatype. 
/// - public static Tensor Seq(params int[] shape) + public static Tensor Seq(params long[] shape) where T : unmanaged, IEquatable { return Tensor.FromArray(Enumerable.Range(0, (int)TensorUtilities.GetProduct(shape)).ToArray()) diff --git a/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs b/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs index 97a1f26d1..518fe12dd 100644 --- a/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs +++ b/src/Nncase.Tests.TestFixture/TransformBase/Compare.cs @@ -24,11 +24,11 @@ public static int GetChannelAxis(int[] shape) return Math.Max(0, 1 - (4 - shape.Length)); } - public static (int Channels, int Size) GetShapeInfo(int[] shape, int channelAxis = 1) + public static (long Channels, long Size) GetShapeInfo(long[] shape, int channelAxis = 1) { var i = channelAxis + 1; - var channels = shape[..i].Aggregate(1, (a, b) => a * b); - var size = shape[i..].Aggregate(1, (a, b) => a * b); + var channels = shape[..i].Aggregate(1L, (a, b) => a * b); + var size = shape[i..].Aggregate(1L, (a, b) => a * b); return (channels, size); } @@ -36,17 +36,17 @@ public static Tensor[] SliceByChannel(Tensor tensor) { var channelAxis = GetChannelAxis(tensor.Shape); var (channels, size) = GetShapeInfo(tensor.Dimensions.ToArray(), channelAxis); - return Enumerable.Range(0, channels).Select(i => + return Enumerable.Range(0, checked((int)channels)).Select(i => SliceTensor(tensor, size * i, size, channelAxis)) .ToArray(); } - private static Tensor SliceTensor(Tensor tensor, int start, int length, int channelAxis = 1) + private static Tensor SliceTensor(Tensor tensor, long start, long length, int channelAxis = 1) { var s = tensor.ElementType.SizeInBytes; return Tensor.FromBytes( tensor.ElementType, - tensor.BytesBuffer.Slice(start * s, length * s).ToArray(), + tensor.BytesBuffer.Slice(checked((int)start * s), checked((int)length * s)).ToArray(), tensor.Dimensions[(channelAxis + 1)..]); } } @@ -436,13 +436,13 @@ public void FailedAssert() public record DetailCompareResultInfo(float[] CosList, AccuracyLossInfo[] AccuracyLossInfos) { - public int[] Shape => AccuracyLossInfos.First().Shape; + public long[] Shape => AccuracyLossInfos.First().Shape; public IEnumerable Enumerable() { var tensorShape = Shape; var (channels, size) = GetShapeInfo(tensorShape, GetChannelAxis(tensorShape)); - return System.Linq.Enumerable.Range(0, channels).Select(c => new CompareResultByChannel(CosList[c], AccuracyLossInfos[(c * size)..((c + 1) * size)])); + return System.Linq.Enumerable.Range(0, checked((int)channels)).Select(c => new CompareResultByChannel(CosList[c], AccuracyLossInfos[checked((int)(c * size))..checked((int)((c + 1) * size))])); } } @@ -460,7 +460,7 @@ public IEnumerable Enumerable() public record CompareResultByChannel(float Cos, AccuracyLossInfo[] LossInfo) { - public int[] Shape => LossInfo.First().Shape; + public long[] Shape => LossInfo.First().Shape; // todo: more analysis strategy public AccuracyLossInfo[] Losses => LossInfo.Where(deviation => @@ -481,7 +481,7 @@ public override string ToString() public record AccuracyLossInfo(float V1, float V2, int[] Index, OriginValue V1Tensor, OriginValue V2Tensor) { - public int[] Shape => V1Tensor.Value.AsTensor().Shape.ToValueArray(); + public long[] Shape => V1Tensor.Value.AsTensor().Shape.ToValueArray(); public float Loss => Math.Abs(V1 - V2); diff --git a/src/Nncase.Tests.TestFixture/TransformBase/DataGenerator.cs b/src/Nncase.Tests.TestFixture/TransformBase/DataGenerator.cs index ac05faf5a..269c2a9ea 100644 --- 
a/src/Nncase.Tests.TestFixture/TransformBase/DataGenerator.cs +++ b/src/Nncase.Tests.TestFixture/TransformBase/DataGenerator.cs @@ -78,8 +78,8 @@ public static Expr DefaultConv() var input = Random.Normal(DataTypes.Float32, new[] { 1, 3, 24, 32 }); var weights = Random.Normal(DataTypes.Float32, new[] { 16, 3, 3, 3 }).Evaluate(); var bias = Random.Normal(DataTypes.Float32, new[] { 16 }).Evaluate(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 1 }, @@ -208,7 +208,7 @@ private static Tensor ParseTensor(DumpData dumpData) // data[1] // ... // data[n] - private static (DataType DataType, int[] Shape, string[] Data, int EndIndex) ParseDumpFile(string[] content, int baseIndex) + private static (DataType DataType, long[] Shape, string[] Data, int EndIndex) ParseDumpFile(string[] content, int baseIndex) { var dtIndex = baseIndex; var shapeIndex = baseIndex + 1; @@ -249,20 +249,20 @@ private static DumpData[] ParseDumpFile(string[] content) // format // shape: x x x x - private static int[] ParseShape(string shapeStr) + private static long[] ParseShape(string shapeStr) { var s = shapeStr.TrimEnd().Split(":")[1]; if (s == "scalar") { - return Array.Empty(); + return Array.Empty(); } - return s.Split(" ").Select(x => int.Parse(x)).ToArray(); + return s.Split(" ").Select(x => long.Parse(x)).ToArray(); } private static DataType ParseDataType(string dt) => DataType.FromTypeCode((Runtime.TypeCode)int.Parse(dt.Split(":")[1])); - private record DumpData(DataType Dt, int[] Shape, string[] Data) + private record DumpData(DataType Dt, long[] Shape, string[] Data) { } } diff --git a/src/Nncase.Tests/Affine/UnitTestFor.cs b/src/Nncase.Tests/Affine/UnitTestFor.cs index 1fc243380..451a82623 100644 --- a/src/Nncase.Tests/Affine/UnitTestFor.cs +++ b/src/Nncase.Tests/Affine/UnitTestFor.cs @@ -57,8 +57,8 @@ public void TestSimpleFor() AffineMap.FromCallable((AffineDomain m, AffineDomain n, AffineDomain k) => new AffineRange[] { new AffineRange(k.Offset, k.Extent), new AffineRange(n.Offset, n.Extent) }), }, AffineMap.FromCallable((AffineDomain m, AffineDomain n) => new AffineRange[] { new AffineRange(m.Offset, m.Extent), new AffineRange(n.Offset, n.Extent), new AffineRange(F.Affine.Dim(2), F.Affine.Extent(2)) })); - var a = Const.FromTensor(Tensor.FromScalar(1f, new[] { dimM, dimK })); - var b = Const.FromTensor(Tensor.FromScalar(2f, new[] { dimK, dimN })); + var a = Const.FromTensor(Tensor.FromScalar(1f, [dimM, dimK])); + var b = Const.FromTensor(Tensor.FromScalar(2f, [dimK, dimN])); var aT2 = F.Affine.For(2, aAccessMap, a[aAccessMap]); var bT2 = F.Affine.For(2, bAccessMap, b[bAccessMap]); diff --git a/src/Nncase.Tests/Core/IR/UnitTestConst.cs b/src/Nncase.Tests/Core/IR/UnitTestConst.cs index 45e4e4585..edd5ac1bd 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestConst.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestConst.cs @@ -197,7 +197,7 @@ public void TestFromTensorValue() [Fact] public void TestFromTupleValue() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var t1 = Tensor.Ones(dims); var t2 = Tensor.Zeros(dims); var tensors = new Tensor[] { t1, t2 }; diff --git a/src/Nncase.Tests/Core/IR/UnitTestDimension.cs b/src/Nncase.Tests/Core/IR/UnitTestDimension.cs index 2cc50f0a9..225eedd4a 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestDimension.cs +++ 
b/src/Nncase.Tests/Core/IR/UnitTestDimension.cs @@ -34,7 +34,7 @@ public void TestKind() Assert.False(d1.IsUnknown); Assert.True(d1.IsFixed); - var d2 = Dimension.Unknown; + var d2 = Dimension.Unknown(); Assert.Equal(DimensionKind.Unknown, d2.Kind); Assert.True(d2.IsUnknown); Assert.False(d2.IsFixed); @@ -67,7 +67,7 @@ public void TestOperatorAdd() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - Dimension d3 = Dimension.Unknown; + var d3 = Dimension.Unknown(); var d4 = d1 + d2; Assert.Equal(v1 + v2, d4.Value); @@ -88,7 +88,7 @@ public void TestOperatorSubtract() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - Dimension d3 = Dimension.Unknown; + var d3 = Dimension.Unknown(); var d4 = d1 - d2; Assert.Equal(v1 - v2, d4.Value); @@ -105,7 +105,7 @@ public void TestOperatorMul() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - Dimension d3 = Dimension.Unknown; + var d3 = Dimension.Unknown(); var d4 = d1 * d2; Assert.Equal(v1 * v2, d4.Value); @@ -122,7 +122,7 @@ public void TestOperatorDiv() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - Dimension d3 = Dimension.Unknown; + var d3 = Dimension.Unknown(); var d4 = d1 / d2; Assert.Equal(v1 / v2, d4.Value); diff --git a/src/Nncase.Tests/Core/IR/UnitTestShape.cs b/src/Nncase.Tests/Core/IR/UnitTestShape.cs index b72e5837b..1a3cdfa21 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestShape.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestShape.cs @@ -267,10 +267,10 @@ public void TestInsertAndCloneOverload2() public void TestToValueList() { int index = 1; - var items = new int[] { 3, 2 }; + var items = new long[] { 3, 2 }; var dimensions = new Dimension[] { 3, 2 }; - var a = new int[] { 1, 2 }; - List expected = new(); + var a = new long[] { 1, 2 }; + List expected = new(); expected.AddRange(a); expected.InsertRange(index, items); @@ -283,10 +283,10 @@ public void TestToValueList() public void TestToValueArray() { int index = 1; - var items = new int[] { 3, 2 }; + var items = new long[] { 3, 2 }; var dimensions = new Dimension[] { 3, 2 }; - var a = new int[] { 1, 2 }; - List list = new(); + var a = new long[] { 1, 2 }; + List list = new(); list.AddRange(a); list.InsertRange(index, items); var expected = list.ToArray(); diff --git a/src/Nncase.Tests/Core/UnitTestDumpUtility.cs b/src/Nncase.Tests/Core/UnitTestDumpUtility.cs index f9261e6db..b16984f9d 100644 --- a/src/Nncase.Tests/Core/UnitTestDumpUtility.cs +++ b/src/Nncase.Tests/Core/UnitTestDumpUtility.cs @@ -19,8 +19,8 @@ public sealed class UnitTestDumpUtility [Fact] public void TestValueDumper() { - ValueDumper.DumpTensor(new TensorValue(new Tensor(new[] { 1 })), "./test1"); - ValueDumper.DumpTensors(new[] { new TensorValue(new Tensor(new[] { 1 })) }, "./test2"); + ValueDumper.DumpTensor(new TensorValue(new Tensor([1])), "./test1"); + ValueDumper.DumpTensors(new[] { new TensorValue(new Tensor([1])) }, "./test2"); Assert.True(File.Exists("./test1")); Assert.True(Directory.Exists("./test2")); } @@ -36,11 +36,11 @@ public void TestDumpUtility() DumpUtility.SerializeShape(new[] { 1, 1, 1 }); DumpUtility.PathJoinByCreate("./", "test4"); - DumpUtility.WriteBinFile("./test5", new Tensor(new[] { 1 })); + DumpUtility.WriteBinFile("./test5", new Tensor([1])); Assert.True(File.Exists("./test3")); Assert.True(Directory.Exists("./test4")); Assert.True(File.Exists("./test5")); - DumpUtility.WriteKmodelData(new Tensor[] { new Tensor(new[] { 1 }) }, new Tensor[] { new Tensor(new[] { 1 }) }, "./test3", "./", true); + DumpUtility.WriteKmodelData(new Tensor[] { new Tensor([1]) }, new Tensor[] { new Tensor([1]) }, 
"./test3", "./", true); Assert.True(File.Exists("./test.kmodel")); Assert.True(File.Exists("./kmodel.desc")); File.Delete("./test.kmodel"); @@ -49,8 +49,8 @@ public void TestDumpUtility() [Fact] public void TestBinFileUtil() { - BinFileUtil.WriteBinInputs(new Tensor[] { new Tensor(new[] { 1 }) }, "./"); - BinFileUtil.WriteBinOutputs(new Tensor[] { new Tensor(new[] { 1 }) }, "./"); + BinFileUtil.WriteBinInputs(new Tensor[] { new Tensor([1]) }, "./"); + BinFileUtil.WriteBinOutputs(new Tensor[] { new Tensor([1]) }, "./"); Assert.True(File.Exists("./input_0_0.bin")); Assert.True(File.Exists("./nncase_result_0.bin")); BinFileUtil.ReadBinFile("./nncase_result_0.bin", DataTypes.Float32, new Shape(1)); diff --git a/src/Nncase.Tests/Core/UnitTestExpression.cs b/src/Nncase.Tests/Core/UnitTestExpression.cs index 38c20c6cf..b860fa8d8 100644 --- a/src/Nncase.Tests/Core/UnitTestExpression.cs +++ b/src/Nncase.Tests/Core/UnitTestExpression.cs @@ -73,8 +73,8 @@ public void TestConstNotEqual() dict.TryAdd(b, 1); Assert.Equal(2, dict.Keys.Count); var arr = new float[] { -0.12399824f, -0.03634571f, 0.5353417f, -0.67039806f, 0.91027457f, -1.0752988f, 0.55657554f, -1.1045103f }; - a = Const.FromTensor(Tensor.From(arr, new[] { 8, 1, 1 })); - b = Const.FromTensor(Tensor.From(arr, new[] { 1, 8, 1 })); + a = Const.FromTensor(Tensor.From(arr, [8, 1, 1])); + b = Const.FromTensor(Tensor.From(arr, [1, 8, 1])); Assert.NotEqual(a, b); Assert.Equal(2, new HashSet(new[] { a, b }).Count); } @@ -212,7 +212,7 @@ public void TestBinaryOpEqualWithCheckType() [Fact] public void TestDenseTenorEqual() { - var t = new Tensor(new[] { 1, 2, 3, 4 }); + var t = Tensor.FromArray(new[] { 1, 2, 3, 4 }); var con = Const.FromTensor(t); var con1 = Const.FromTensor(t); Assert.Equal(con, con1); @@ -222,7 +222,7 @@ public void TestDenseTenorEqual() [Fact] public void TestConstToDenseTenor() { - var con = Const.FromTensor(Tensor.From(new[] { 1, 2, 3, 4, 5 }, new[] { 5 })); + var con = Const.FromTensor(Tensor.From(new[] { 1, 2, 3, 4, 5 }, [5])); var t = con.Value.Cast(); Assert.Equal(1, t[0]); Assert.Equal(2, t[1]); @@ -252,7 +252,7 @@ public void TestConstToDenseTenor() [Fact] public void TestDenseTensorLength() { - var t = new Tensor(new[] { 1, 2, 3, 4 }, new[] { 2, 2 }); + var t = new Tensor(new[] { 1, 2, 3, 4 }, [2, 2]); Assert.Equal(4, t.Length); Assert.Equal(2, t.Dimensions[0]); } diff --git a/src/Nncase.Tests/Core/UnitTestGetReplaceUtility.cs b/src/Nncase.Tests/Core/UnitTestGetReplaceUtility.cs index 02a54305a..fd821683c 100644 --- a/src/Nncase.Tests/Core/UnitTestGetReplaceUtility.cs +++ b/src/Nncase.Tests/Core/UnitTestGetReplaceUtility.cs @@ -22,7 +22,7 @@ public sealed class UnitTestGetReplaceUtility [Fact] public void TestGetReplaceUtility() { - Assert.Throws(() => Utility.Get4DGNNEShape(new[] { 0, 1, 2, 3, 4 })); + Assert.Throws(() => Utility.Get4DGNNEShape([0, 1, 2, 3, 4])); } [Fact] @@ -30,7 +30,7 @@ public void WithTmp4DShape_WhenGivenFunctionAndOutputShape_ShouldInsertRehsapeBe { var input = IR.F.Random.Normal(DataTypes.UInt8, new[] { 1, 2, 3, 4 }); Fx inputCtor = expr => IR.F.Tensors.Flatten(expr, 1); - var originOutShape = new[] { 1, 2, 3, 4 }; + var originOutShape = new long[] { 1, 2, 3, 4 }; var output = Utility.WithTmp4DShape(inputCtor, originOutShape)(input); diff --git a/src/Nncase.Tests/Core/UnitTestIValue.cs b/src/Nncase.Tests/Core/UnitTestIValue.cs index a6aee9a98..11587d7df 100644 --- a/src/Nncase.Tests/Core/UnitTestIValue.cs +++ b/src/Nncase.Tests/Core/UnitTestIValue.cs @@ -18,17 +18,17 @@ public sealed class 
UnitTestIValue public static IEnumerable TestTensorValueCountData => new[] { - new object[] { Tensor.Ones(new int[] { 1, 3, 16, 16 }) }, - new object[] { Tensor.Zeros(new int[] { 1, 3, 16, 16 }) }, - new object[] { Tensor.From(new int[] { 1, 2, 3, 4 }, new int[] { 2, 2 }) }, + new object[] { Tensor.Ones([1, 3, 16, 16]) }, + new object[] { Tensor.Zeros([1, 3, 16, 16]) }, + new object[] { Tensor.From(new int[] { 1, 2, 3, 4 }, [2, 2]) }, }; public static IEnumerable TestTupleValueCountData => new[] { - new object[] { new Tensor[] { Tensor.Ones(new int[] { 1, 3, 16, 16 }) } }, - new object[] { new Tensor[] { Tensor.Ones(new int[] { 1, 3, 16, 16 }), Tensor.Zeros(new int[] { 1, 3, 16, 16 }) } }, - new object[] { new Tensor[] { Tensor.Ones(new int[] { 1, 3, 16, 16 }), Tensor.Zeros(new int[] { 1, 3, 16, 16 }), Tensor.From(new int[] { 1, 2, 3, 4 }, new int[] { 2, 2 }) } }, + new object[] { new Tensor[] { Tensor.Ones([1, 3, 16, 16]) } }, + new object[] { new Tensor[] { Tensor.Ones([1, 3, 16, 16]), Tensor.Zeros([1, 3, 16, 16]) } }, + new object[] { new Tensor[] { Tensor.Ones([1, 3, 16, 16]), Tensor.Zeros([1, 3, 16, 16]), Tensor.From(new int[] { 1, 2, 3, 4 }, [2, 2]) } }, }; [Fact] @@ -64,7 +64,7 @@ public void TestNoneValueException() [Fact] public void TestTensorValueType() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var a = new TensorValue(Tensor.Ones(dims)); Assert.Equal(new TensorType(DataTypes.Float32, dims), a.Type); } @@ -80,7 +80,7 @@ public void TestTensorValueCount(Tensor t) [Fact] public void TestTensorValueIndex() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); var a = new TensorValue(ones); Assert.Equal(a, a[0]); Assert.Throws(() => a[1]); @@ -89,7 +89,7 @@ public void TestTensorValueIndex() [Fact] public void TestTensorValueAsTensor() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); var a = new TensorValue(ones); Assert.Equal(ones, a.AsTensor()); } @@ -97,7 +97,7 @@ public void TestTensorValueAsTensor() [Fact] public void TestTensorValueAsTensors() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); var a = new TensorValue(ones); Assert.Equal(ones, a.AsTensors()[0]); } @@ -105,8 +105,8 @@ public void TestTensorValueAsTensors() [Fact] public void TestTensorValueCompare() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); - var zeros = Tensor.Zeros(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); + var zeros = Tensor.Zeros([1, 3, 16, 16]); var a = new TensorValue(ones); var b = a; var c = new TensorValue(ones); @@ -127,7 +127,7 @@ public void TestTensorValueCompare() [Fact] public void TestTensorValueGetHashCode() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); var a = new TensorValue(ones); Assert.Equal(HashCode.Combine(ones), a.GetHashCode()); } @@ -135,7 +135,7 @@ public void TestTensorValueGetHashCode() [Fact] public void TestTupleValueType() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var tensor1 = Tensor.Ones(dims); var tensor2 = Tensor.Zeros(dims); var tensors = new Tensor[] { tensor1, tensor2 }; @@ -154,7 +154,7 @@ public void TestTupleValueCount(Tensor[] tensors) [Fact] public void TestTupleValueIndex() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var tensor1 = Tensor.Ones(dims); var tensor2 = Tensor.Zeros(dims); var tensors = new Tensor[] { 
tensor1, tensor2 }; @@ -166,7 +166,7 @@ public void TestTupleValueIndex() [Fact] public void TestTupleValueAsTensor() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var tensor1 = Tensor.Ones(dims); var tensor2 = Tensor.Zeros(dims); var tensors = new Tensor[] { tensor1, tensor2 }; @@ -177,7 +177,7 @@ public void TestTupleValueAsTensor() [Fact] public void TestTupleValueAsTensors() { - var dims = new int[] { 1, 3, 16, 16 }; + var dims = new long[] { 1, 3, 16, 16 }; var tensor1 = Tensor.Ones(dims); var tensor2 = Tensor.Zeros(dims); var tensors = new Tensor[] { tensor1, tensor2 }; @@ -189,8 +189,8 @@ public void TestTupleValueAsTensors() [Fact] public void TestTupleValueCompare() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); - var zeros = Tensor.Zeros(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); + var zeros = Tensor.Zeros([1, 3, 16, 16]); var tensors = new Tensor[] { ones, zeros }; var a = Value.FromTensors(tensors); var b = a; @@ -204,7 +204,7 @@ public void TestTupleValueCompare() [Fact] public void TestTupleValueGetHashCode() { - var ones = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var ones = Tensor.Ones([1, 3, 16, 16]); var tensors = new Tensor[] { ones, ones }; var values = tensors.Select(x => new TensorValue(x)).ToArray(); var a = new TupleValue(values); diff --git a/src/Nncase.Tests/Core/UnitTestTensor.cs b/src/Nncase.Tests/Core/UnitTestTensor.cs index 5e1a46922..f26a935ec 100644 --- a/src/Nncase.Tests/Core/UnitTestTensor.cs +++ b/src/Nncase.Tests/Core/UnitTestTensor.cs @@ -16,7 +16,7 @@ public sealed class UnitTestTensor public void TestICollection() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t = (ICollection)Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t = (ICollection)Tensor.From(a, [1, 1, 2, 4]); Assert.Equal(a.Length, t.Count); Assert.False(t.IsSynchronized); Assert.Equal((object)t, t.SyncRoot); @@ -26,7 +26,7 @@ public void TestICollection() public void TestIList() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t = Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t = Tensor.From(a, [1, 1, 2, 4]); var list = (IList)t; Assert.True(list.IsFixedSize); Assert.False(list.IsReadOnly); @@ -36,7 +36,7 @@ public void TestIList() Assert.Equal(100f, list[0]); list.Clear(); - var expected = Tensor.Zeros(new int[] { 1, 1, 2, 4 }); + var expected = Tensor.Zeros([1, 1, 2, 4]); Assert.Equal(expected, t); } @@ -44,11 +44,11 @@ public void TestIList() public void TestIndices() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t = Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t = Tensor.From(a, [1, 1, 2, 4]); - Assert.Equal(7, t[new int[] { 0, 0, 1, 2 }]); - t[new int[] { 0, 0, 1, 2 }] = 700; - Assert.Equal(700, t[new int[] { 0, 0, 1, 2 }]); + Assert.Equal(7, t[0, 0, 1, 2]); + t[0, 0, 1, 2] = 700; + Assert.Equal(700, t[0, 0, 1, 2]); } // Tensor FromBytes(Memory memory, ReadOnlySpan dimensions) @@ -57,7 +57,7 @@ public void TestFromBytesOverload1() { var a = new byte[] { 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40 }; var expected = new float[] { 1, 2, 3, 4 }; - var t = Tensor.FromBytes(new Memory(a), new int[] { 1, 1, 2, 2 }); + var t = Tensor.FromBytes(new Memory(a), [1, 1, 2, 2]); Assert.Equal(DataTypes.Float32, t.ElementType); Assert.Equal(expected, t.ToArray()); } @@ -68,7 +68,7 @@ public void TestFromBytesOverload2() { var a = new byte[] { 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40 }; var 
expected = new float[] { 1, 2, 3, 4 }; - var t = Tensor.FromBytes(DataTypes.Float32, new Memory(a), new int[] { 1, 1, 2, 2 }); + var t = Tensor.FromBytes(DataTypes.Float32, new Memory(a), [1, 1, 2, 2]); Assert.Equal(DataTypes.Float32, t.ElementType); Assert.Equal(expected, t.ToArray()); } @@ -101,7 +101,7 @@ public void TestFromBytesOverload5() { var a = new byte[] { 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40 }; var expected = new Vector4[] { Vector4.Create(new[] { 1.0f, 2.0f, 3.0f, 4.0f }) }; - var t = Tensor.FromBytes>(new Memory(a), new[] { 1 }); + var t = Tensor.FromBytes>(new Memory(a), [1]); Assert.Equal(new VectorType(DataTypes.Float32, 4), t.ElementType); Assert.Equal(expected, t.ToArray>()); } @@ -111,7 +111,7 @@ public void TestFromBytesWithPad() { var a = new byte[] { 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; var expected = new Vector4[] { Vector4.Create(new[] { 1.0f, 0.0f, 0.0f, 0.0f }) }; - var t = Tensor.FromBytes>(new Memory(a), new int[] { 1 }); + var t = Tensor.FromBytes>(new Memory(a), [1]); Assert.Equal(new VectorType(DataTypes.Float32, 4), t.ElementType); Assert.Equal(expected, t.ToArray>()); } @@ -153,7 +153,7 @@ public unsafe void TestFromPointerET() public unsafe void TestFromConstOverload1() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t = Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t = Tensor.From(a, [1, 1, 2, 4]); var tensorConst1 = new TensorConst(t); var tensorConst2 = tensorConst1; @@ -170,7 +170,7 @@ public unsafe void TestFromConstOverload1() public unsafe void TestFromConstOverload2() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t1 = Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t1 = Tensor.From(a, [1, 1, 2, 4]); var tensorConst1 = new TensorConst(t1); var expected = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; @@ -182,7 +182,7 @@ public unsafe void TestFromConstOverload2() public void TestListException() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var t = (IList)Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var t = (IList)Tensor.From(a, [1, 1, 2, 4]); Assert.Throws(() => t.Add(100)); Assert.Throws(() => t.Insert(0, 100)); diff --git a/src/Nncase.Tests/Core/UnitTestTensorHelper.cs b/src/Nncase.Tests/Core/UnitTestTensorHelper.cs index ac9dc5f43..8f7707f14 100644 --- a/src/Nncase.Tests/Core/UnitTestTensorHelper.cs +++ b/src/Nncase.Tests/Core/UnitTestTensorHelper.cs @@ -15,7 +15,7 @@ public sealed class UnitTestTensorHelper public void TestToArray() { var a = new float[] { 1, 1, 1, 1, 1, 1, 1, 1 }; - var t = Tensor.Ones(new int[] { 1, 1, 2, 4 }); + var t = Tensor.Ones([1, 1, 2, 4]); var b = t.ToArray(); Assert.Equal(a, b); } @@ -23,14 +23,14 @@ public void TestToArray() [Fact] public void TestToScalar1() { - var t = Tensor.Ones(new int[] { 1 }); + var t = Tensor.Ones([1]); Assert.Equal(1F, t.ToScalar()); } [Fact] public void TestToScalar2() { - var t = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var t = Tensor.Ones([1, 3, 16, 16]); Assert.Throws(() => t.ToScalar()); } @@ -41,10 +41,10 @@ public void TestToStr() string expected = "hello, world!"; var bytes = utf8.GetBytes(expected); - var t1 = Tensor.FromBytes(DataTypes.Utf8Char, new Memory(bytes), new int[] { bytes.Length }); + var t1 = Tensor.FromBytes(DataTypes.Utf8Char, new Memory(bytes), [bytes.Length]); Assert.Equal(expected, t1.ToStr()); - var t2 = Tensor.Ones(new int[] { 1, 3, 16, 16 }); + var t2 = Tensor.Ones([1, 3, 16, 16]); Assert.Throws(() => t2.ToStr()); } } diff 
--git a/src/Nncase.Tests/Core/UnitTestTensorOfT.Helper.cs b/src/Nncase.Tests/Core/UnitTestTensorOfT.Helper.cs index b07a1e885..1fc7e1fdd 100644 --- a/src/Nncase.Tests/Core/UnitTestTensorOfT.Helper.cs +++ b/src/Nncase.Tests/Core/UnitTestTensorOfT.Helper.cs @@ -14,7 +14,7 @@ public sealed class UnitTestTensorOfTHelper public void TestToArray() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var tensor = Tensor.From(a, new int[] { 1, 1, 2, 4 }); + var tensor = Tensor.From(a, [1, 1, 2, 4]); var b = tensor.ToArray(); Assert.Equal(a, b); } @@ -27,7 +27,7 @@ public void TestToScalar() var t1 = (Tensor)scalar; Assert.Equal(scalar, t1.ToScalar()); - var t2 = new Tensor(new int[] { 1, 3, 16, 16 }); + var t2 = new Tensor([1, 3, 16, 16]); Assert.Throws(() => t2.ToScalar()); } } diff --git a/src/Nncase.Tests/Core/UnitTestTensorOfT.cs b/src/Nncase.Tests/Core/UnitTestTensorOfT.cs index 7f276e882..f3be525bb 100644 --- a/src/Nncase.Tests/Core/UnitTestTensorOfT.cs +++ b/src/Nncase.Tests/Core/UnitTestTensorOfT.cs @@ -29,7 +29,7 @@ public void TestICollection() { var a1 = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a1); - var t = new Tensor(memory, new int[] { 1, 1, 2, 4 }); + var t = new Tensor(memory, [1, 1, 2, 4]); ICollection c = t; Assert.Equal(a1.Length, c.Count); Assert.False(c.IsReadOnly); @@ -66,7 +66,7 @@ public void TestIReadOnlyList() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 8 }); + var t = new Tensor(memory, [8]); var list = (IReadOnlyList)t; for (int i = 0; i < a.Length; i++) { @@ -79,7 +79,7 @@ public void TestIList() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 8 }); + var t = new Tensor(memory, [8]); var list = (IList)t; for (int i = 0; i < a.Length; i++) @@ -110,7 +110,7 @@ public void TestClone() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t1 = new Tensor(memory, new int[] { 8 }); + var t1 = new Tensor(memory, [8]); var t2 = t1.Clone(); Assert.Equal(t1, t2); Assert.NotSame(t1, t2); @@ -119,8 +119,8 @@ public void TestClone() [Fact] public void TestCloneEmpty() { - var t1 = new Tensor(new int[] { 1, 2, 3, 4 }); - var t2 = t1.CloneEmpty(new int[] { 1, 3, 16, 16 }); + var t1 = new Tensor([1, 2, 3, 4]); + var t2 = t1.CloneEmpty([1, 3, 16, 16]); var a = (ICollection)t2; Assert.Equal(1 * 3 * 16 * 16, a.Count); } @@ -130,7 +130,7 @@ public void TestGetValue() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 2, 4 }); + var t = new Tensor(memory, [2, 4]); Assert.Equal(2, t.GetValue(1)); Assert.Equal(7, t.GetValue(6)); Assert.Throws(() => t.GetValue(a.Length)); @@ -142,10 +142,10 @@ public void TestReshape() var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t1 = new Tensor(memory, new int[] { 1, 1, 2, 4 }); - Assert.Throws(() => t1.Reshape(new int[] { 1, 1, 4, 3 })); + var t1 = new Tensor(memory, [1, 1, 2, 4]); + Assert.Throws(() => t1.Reshape([1, 1, 4, 3])); - var t2 = t1.Reshape(new int[] { 1, 1, 4, 2 }); + var t2 = t1.Reshape([1, 1, 4, 2]); Assert.True(t2.Buffer.Span.SequenceEqual(t1.Buffer.Span)); } @@ -155,7 +155,7 @@ public void TestSetValue(int index, int value) { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 2, 4 }); + var t = new Tensor(memory, [2, 4]); t.SetValue(index, value); Assert.Equal(value, t.GetValue(index)); 
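// ---------------------------------------------------------------------------
// Editor's illustration (not part of this patch): the Tensor<T> tests above
// migrate shape and index arguments from new int[] { ... } to long-based
// collection expressions. A minimal sketch of that pattern, assuming only the
// members these tests exercise (the Tensor<int>(Memory<int>, shape) constructor,
// GetValue/SetValue, the multi-dimensional indexer, and Reshape):
//
//     using System;
//     using Nncase;
//
//     var memory = new Memory<int>(new[] { 1, 2, 3, 4, 5, 6, 7, 8 });
//     var t = new Tensor<int>(memory, [2, 4]);   // shape literal binds to the long-based overload
//     t.SetValue(6, 70);                         // flat element index, as in TestSetValue
//     Console.WriteLine(t[1, 2]);                // multi-dimensional indexer -> 70
//     var r = t.Reshape([1, 1, 4, 2]);           // same buffer, new long[] shape
//     Console.WriteLine(r.Dimensions[2]);        // 4
// ---------------------------------------------------------------------------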
Assert.Throws(() => t.SetValue(a.Length, 9)); @@ -166,7 +166,7 @@ public void TestFill() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 2, 4 }); + var t = new Tensor(memory, [2, 4]); int value = 100; t.Fill(value); for (int i = 0; i < a.Length; i++) @@ -180,7 +180,7 @@ public void TestContains() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 2, 4 }); + var t = new Tensor(memory, [2, 4]); Assert.Contains(1, t); Assert.Contains(8, t); } @@ -190,19 +190,19 @@ public void TestTryGetIndicesOf() { var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t = new Tensor(memory, new int[] { 1, 1, 2, 4 }); + var t = new Tensor(memory, [1, 1, 2, 4]); - var indices1 = new int[] { 1, 1, 1, 1 }; + var indices1 = new long[] { 1, 1, 1, 1 }; Assert.True(t.TryGetIndicesOf(7, indices1)); Assert.Equal(0, indices1[0]); Assert.Equal(0, indices1[1]); Assert.Equal(1, indices1[2]); Assert.Equal(2, indices1[3]); - var indices2 = new int[] { 1, 1 }; + var indices2 = new long[] { 1, 1 }; Assert.Throws(() => t.TryGetIndicesOf(7, indices2)); - var indices3 = new int[] { 1, 1, 1, 1 }; + var indices3 = new long[] { 1, 1, 1, 1 }; Assert.False(t.TryGetIndicesOf(100, indices3)); } @@ -213,14 +213,14 @@ public void TestEqualsOverload1() var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t1 = new Tensor(memory, new int[] { 1, 1, 2, 4 }); - var t2 = new Tensor(memory, new int[] { 2, 4 }); + var t1 = new Tensor(memory, [1, 1, 2, 4]); + var t2 = new Tensor(memory, [2, 4]); Assert.False(t1.Equals((object)t2)); - var t3 = new Tensor(memory, new int[] { 1, 1, 4, 2 }); + var t3 = new Tensor(memory, [1, 1, 4, 2]); Assert.False(t1.Equals((object)t3)); - var t4 = new Tensor(memory, new int[] { 1, 1, 2, 4 }); + var t4 = new Tensor(memory, [1, 1, 2, 4]); Assert.True(t1.Equals((object)t4)); Assert.False(t1.Equals(a)); @@ -233,18 +233,18 @@ public void TestEqualsOverload2() var a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var memory = new Memory(a); - var t1 = new Tensor(memory, new int[] { 1, 1, 2, 4 }); + var t1 = new Tensor(memory, [1, 1, 2, 4]); Tensor? 
n = null; Assert.NotStrictEqual(t1, n); Assert.False(t1.Equals(n)); - var t2 = new Tensor(memory, new int[] { 2, 4 }); + var t2 = new Tensor(memory, [2, 4]); Assert.NotStrictEqual(t1, t2); - var t3 = new Tensor(memory, new int[] { 1, 1, 4, 2 }); + var t3 = new Tensor(memory, [1, 1, 4, 2]); Assert.NotStrictEqual(t1, t3); - var t4 = new Tensor(memory, new int[] { 1, 1, 2, 4 }); + var t4 = new Tensor(memory, [1, 1, 2, 4]); Assert.StrictEqual(t1, t4); } } diff --git a/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs b/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs index 7e8a70c2b..bb66552dc 100644 --- a/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs +++ b/src/Nncase.Tests/Core/UnitTestTensorUtilities.cs @@ -16,104 +16,104 @@ public sealed class UnitTestTensorUtilities new[] { new object[] { 23, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 0 }, - new object[] { 23, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 1 }, - new object[] { 11, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 2 }, - new object[] { 3, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 3 }, + [23, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 1], + [11, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 2], + [3, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, 3], }; public static unsafe IEnumerable TestGetIndexOverload2Data => new[] { new object[] { 23, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 0 }, - new object[] { 23, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 1 }, - new object[] { 11, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 2 }, - new object[] { 3, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 3 }, + [23, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 1], + [11, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 2], + [3, new Expr[] { 24, 12, 4, 1 }, new Expr[] { 0, 1, 2, 3 }, 3], }; public static unsafe IEnumerable TestSplitStridesData => new[] { new object[] { new int[] { 24, 12, 4, 1 }, Array.Empty(), new int[] { 24, 12, 4, 1 }, Array.Empty(), new int[4], 0, Array.Empty(), 0 }, - new object[] { new int[] { 12, 4, 1 }, new int[] { 24 }, new int[] { 24, 12, 4, 1 }, new int[] { 0 }, new int[3], 0, new int[1], 0 }, - new object[] { new int[] { 24, 4, 1 }, new int[] { 12 }, new int[] { 24, 12, 4, 1 }, new int[] { 1 }, new int[3], 0, new int[1], 0 }, - new object[] { new int[] { 24, 12, 1 }, new int[] { 4 }, new int[] { 24, 12, 4, 1 }, new int[] { 2 }, new int[3], 0, new int[1], 0 }, - new object[] { new int[] { 24, 12, 4 }, new int[] { 1 }, new int[] { 24, 12, 4, 1 }, new int[] { 3 }, new int[3], 0, new int[1], 0 }, - new object[] { new int[] { 4, 1, }, new int[] { 24, 12 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1 }, new int[2], 0, new int[2], 0 }, - new object[] { new int[] { 12, 1 }, new int[] { 24, 4 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 2 }, new int[2], 0, new int[2], 0 }, - new object[] { Array.Empty(), new int[] { 24, 12, 4, 1 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, Array.Empty(), 0, new int[4], 0 }, + [new int[] { 12, 4, 1 }, new int[] { 24 }, new int[] { 24, 12, 4, 1 }, new int[] { 0 }, new int[3], 0, new int[1], 0], + [new int[] { 24, 4, 1 }, new int[] { 12 }, new int[] { 24, 12, 4, 1 }, new int[] { 1 }, new int[3], 0, new int[1], 0], + [new int[] { 24, 12, 1 }, new int[] { 4 }, new int[] { 24, 12, 4, 1 }, new int[] { 2 }, new int[3], 0, new int[1], 0], + [new int[] { 24, 12, 4 }, new int[] { 1 }, new int[] { 24, 12, 4, 1 }, new int[] { 3 }, new int[3], 0, new int[1], 
0], + [new int[] { 4, 1, }, new int[] { 24, 12 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1 }, new int[2], 0, new int[2], 0], + [new int[] { 12, 1 }, new int[] { 24, 4 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 2 }, new int[2], 0, new int[2], 0], + [Array.Empty(), new int[] { 24, 12, 4, 1 }, new int[] { 24, 12, 4, 1 }, new int[] { 0, 1, 2, 3 }, Array.Empty(), 0, new int[4], 0], }; public static unsafe IEnumerable TestTransformIndexByStridesData => new[] { - new object[] { 4, 20, new int[] { 24, 12, 4, 1 }, false, new int[] { 6, 2, 1, 1 } }, - new object[] { 5, 20, new int[] { 1, 4, 12, 24 }, true, new int[] { 6, 2, 1, 1 } }, + new object[] { 4, 20, new long[] { 24, 12, 4, 1 }, false, new long[] { 6, 2, 1, 1 } }, + [5, 20, new long[] { 1, 4, 12, 24 }, true, new long[] { 6, 2, 1, 1 }], }; [Fact] public void TestIsContiguousSlice() { - var dim1 = new[] { 1, 512, 14, 14 }; + var dim1 = new long[] { 1, 512, 14, 14 }; int start; Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..14, 0..14 }, + [0..1, 0..512, 0..14, 0..14], out start)); Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 0..1, 0..14 }, + [0..1, 0..1, 0..1, 0..14], out start)); Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 0..1, 7..14 }, + [0..1, 0..1, 0..1, 7..14], out start)); Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..1, 7..14, 0..14 }, + [0..1, 0..1, 7..14, 0..14], out start)); Assert.Equal(0, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..14 }, + [0..1, 0..512, 0..7, 0..14], out start)); Assert.Equal(2, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..14, 0..1 }, + [0..1, 0..512, 0..7, 0..14, 0..1], out start)); Assert.Equal(4, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 10..512, 0..1, 0..1 }, + [0..1, 10..512, 0..1, 0..1], out start)); Assert.Equal(2, start); Assert.False(TensorUtilities.IsContiguousSlice( dim1, - new[] { 0..1, 0..512, 0..7, 0..1 }, + [0..1, 0..512, 0..7, 0..1], out start)); Assert.Equal(3, start); - var dim2 = new[] { 1, 512, 1, 196 }; + var dim2 = new long[] { 1, 512, 1, 196 }; Assert.True(TensorUtilities.IsContiguousSlice( dim2, - new[] { 0..1, 0..128, 0..1, 0..196 }, + [0..1, 0..128, 0..1, 0..196], out start)); Assert.Equal(0, start); Assert.True(TensorUtilities.IsContiguousSlice( dim2, - new[] { 0..1, 0..1, 0..1, 10..15 }, + [0..1, 0..1, 0..1, 10..15], out start)); Assert.Equal(0, start); } @@ -147,26 +147,26 @@ public void TestGetProductOverload2() [Fact] public void TestIsAsending() { - var a = Enumerable.Range(1, 100).ToArray(); + var a = Enumerable.Range(1, 100).ToArray().ToLongs(); Assert.True(TensorUtilities.IsAscending(a)); - var b = Enumerable.Repeat(1, 100).ToArray(); + var b = Enumerable.Repeat(1, 100).ToArray().ToLongs(); Assert.True(TensorUtilities.IsAscending(b)); - var c = Enumerable.Range(1, 100).Reverse().ToArray(); + var c = Enumerable.Range(1, 100).Reverse().ToArray().ToLongs(); Assert.False(TensorUtilities.IsAscending(c)); } [Fact] public void TestIsDescending() { - var a = Enumerable.Range(1, 100).ToArray(); + var a = Enumerable.Range(1, 100).ToArray().ToLongs(); Assert.False(TensorUtilities.IsDescending(a)); - var b = Enumerable.Repeat(1, 100).ToArray(); + var b = Enumerable.Repeat(1, 100).ToArray().ToLongs(); Assert.True(TensorUtilities.IsDescending(b)); - var c = 
Enumerable.Range(1, 100).Reverse().ToArray(); + var c = Enumerable.Range(1, 100).Reverse().ToArray().ToLongs(); Assert.True(TensorUtilities.IsDescending(c)); } @@ -215,11 +215,11 @@ public void TestGetStridesOverload2() [Fact] public void TestGetSize() { - var shapes = new[] { 1, 2, 4, 8 }; - var strides = new[] { 1, 1, 1, 1 }; + var shapes = new long[] { 1, 2, 4, 8 }; + var strides = new long[] { 1, 1, 1, 1 }; var elementSize = 1; var getSize = TensorUtilities.GetSize(shapes, strides, elementSize); - var result = 1; + long result = 1; for (int i = 0; i < shapes.Length; i++) { result += (shapes[i] - 1) * strides[i]; @@ -243,11 +243,11 @@ public void TestGetIndexOverload1Exception() { // stride is empty var stride1 = Array.Empty(); - Assert.Equal(0, TensorUtilities.GetIndex(stride1, new int[] { 0 })); + Assert.Equal(0, TensorUtilities.GetIndex(stride1, [0])); // exception - Assert.Throws(() => TensorUtilities.GetIndex(stride1, new int[] { 0, 1 })); - Assert.Throws(() => TensorUtilities.GetIndex(stride1, new int[] { 1 })); + Assert.Throws(() => TensorUtilities.GetIndex(stride1, [0, 1])); + Assert.Throws(() => TensorUtilities.GetIndex(stride1, [1])); } [Theory] @@ -268,7 +268,7 @@ public void TestGetIndexOverload2Exception() Assert.Equal(0, actual1.Evaluate().AsTensor().ToScalar()); // exception - Assert.Throws(() => TensorUtilities.GetIndex(stride1, new Expr[] { 0, 1 })); + Assert.Throws(() => TensorUtilities.GetIndex(stride1, [0, 1])); } [Theory] @@ -282,7 +282,7 @@ public void TestGetIndexOverload2(Expr expect, Expr[] strides, Expr[] indices, i [Theory] [MemberData(nameof(TestTransformIndexByStridesData))] - public void TestTransformIndexByStrides(int expect, int index, int[] sourceStrides, bool sourceReverseStride, int[] transformStrides) + public void TestTransformIndexByStrides(int expect, int index, long[] sourceStrides, bool sourceReverseStride, long[] transformStrides) { var actual = TensorUtilities.TransformIndexByStrides(index, sourceStrides, sourceReverseStride, transformStrides); Assert.Equal(expect, actual); @@ -291,8 +291,8 @@ public void TestTransformIndexByStrides(int expect, int index, int[] sourceStrid [Fact] public void TestIsContiguous() { - Assert.True(TensorUtilities.IsContiguous(new int[] { 1, 2, 3, 4 }, new int[] { 24, 12, 4, 1 })); - Assert.False(TensorUtilities.IsContiguous(new int[] { 1, 2, 3, 4 }, new int[] { 24, 12, 4 })); - Assert.False(TensorUtilities.IsContiguous(new int[] { 1, 2, 3, 4 }, new int[] { 24, 12, 3, 1 })); + Assert.True(TensorUtilities.IsContiguous([1, 2, 3, 4], [24, 12, 4, 1])); + Assert.False(TensorUtilities.IsContiguous([1, 2, 3, 4], [24, 12, 4])); + Assert.False(TensorUtilities.IsContiguous([1, 2, 3, 4], [24, 12, 3, 1])); } } diff --git a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs index ae37574ba..261543a43 100644 --- a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs +++ b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs @@ -60,7 +60,7 @@ public UnitTestTypeInfer() public void TestInferBinary() { var a = new Var(new TensorType(DataTypes.Float32, new[] { 1, 5, 1 })); - var b = Tensor.FromScalar(1.0f, new[] { 1, 5, 3 }); + var b = Tensor.FromScalar(1.0f, [1, 5, 3]); var c = a + b; _ = CompilerServices.InferenceType(c); @@ -103,7 +103,7 @@ public void TestSlice() [Fact] public void TestSlice2() { - var input_a = new Var("input_a", new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, Dimension.Unknown, Dimension.Unknown })); + var input_a = new Var("input_a", new TensorType(DataTypes.Float32, 
Shape.Unknown(3))); var repeats = IR.F.Tensors.Slice(IR.F.Tensors.ShapeOf(input_a), new[] { -2 }, new[] { -1 }, 1); Assert.True(CompilerServices.InferenceType(repeats)); Assert.True(repeats.CheckedShape.Rank == 1); @@ -284,6 +284,7 @@ public UnitTestDynamicTypeInfer() { } +#if false [Fact] public void TestRange() { @@ -293,6 +294,7 @@ public void TestRange() var r = Range(begin, end, step); CheckInferShape(r, Dimension.Unknown); } +#endif [Fact] public void TestConcat() @@ -317,10 +319,11 @@ public void TestBroadcastInfer() [Fact] public void TestBroadcastInfer2() { - var a = new TensorType(DataTypes.Float32, new Dimension[] { 1, Dimension.Unknown, 8192 }); + var dimUnk1 = Dimension.Unknown(); + var a = new TensorType(DataTypes.Float32, new Dimension[] { 1, dimUnk1, 8192 }); var b = new TensorType(DataTypes.Float32, new Dimension[] { 1 }); var result = TypeInference.BroadcastType(a, b); - Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, Dimension.Unknown, 8192 }), result); + Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimUnk1, 8192 }), result); } private void CheckInferShape(Expr expr, params Dimension[] shapeDimensions) diff --git a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs index 16016b616..1130e8ab9 100644 --- a/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs +++ b/src/Nncase.Tests/Diagnostics/UnitTestDumpper.cs @@ -187,7 +187,7 @@ public async Task TestSubDumperDumpFlags() [Fact] public void TestDumperCSharpIRFunction() { - var x = IR.F.Math.Quantize(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, new[] { 1, 2, 2, 2 }), Tensor.From(new QuantParam[] { new(1, 2.0f), new(2, 3.0f) }, new[] { 2 }), DataTypes.UInt8); + var x = IR.F.Math.Quantize(IR.F.Random.Normal(DataTypes.Float32, 0, 1, 0, new[] { 1, 2, 2, 2 }), Tensor.From(new QuantParam[] { new(1, 2.0f), new(2, 3.0f) }, [2]), DataTypes.UInt8); var y = new Var("y", new TensorType(DataTypes.UInt8, new int[] { 1, 2, 2, 2 })); var z = IR.F.Random.Normal(DataTypes.UInt8, 0, 1, 0, new[] { 1, 2, 2, 2 }); var m = IR.F.Random.Normal(DataTypes.UInt8, 0, 1, 0, new[] { 1, 20, 2, 2 }); diff --git a/src/Nncase.Tests/Distributed/UnitTestCustomOpScheme.cs b/src/Nncase.Tests/Distributed/UnitTestCustomOpScheme.cs index 29e18da6e..d750207fd 100644 --- a/src/Nncase.Tests/Distributed/UnitTestCustomOpScheme.cs +++ b/src/Nncase.Tests/Distributed/UnitTestCustomOpScheme.cs @@ -17,7 +17,7 @@ public class UnitTestCustomOpScheme : TestClassBase [Fact] public void TestExportScheme() { - var scheme = new CustomOpScheme("1", "matmul", new CustomOpScheme.Node[] { new CustomOpScheme.Node(string.Empty, "Matmul", new[] { new[] { 32, 32 }, new[] { 32, 32 } }, new[] { new SBP[] { SBP.B, SBP.B, SBP.B }, new SBP[] { SBP.B, SBP.B, SBP.S(1) } }, 1, string.Empty) }); + var scheme = new CustomOpScheme("1", "matmul", new CustomOpScheme.Node[] { new CustomOpScheme.Node(string.Empty, "Matmul", [[32, 32], [32, 32]], new[] { new SBP[] { SBP.B, SBP.B, SBP.B }, new SBP[] { SBP.B, SBP.B, SBP.S(1) } }, 1, string.Empty) }); var except = @"{ ""Version"": ""1"", ""Model"": ""matmul"", diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs index 037e40793..54c3f09f0 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluator.cs @@ -81,7 +81,7 @@ public void TestTFResizeImage() var input = OrtKI.Random(1, 3, 224, 224).ToTensor(); var image = Imaging.ResizeImage(ImageResizeMode.Bilinear, 
input, Array.Empty(), new[] { 1, 3, 112, 112 }, isTFResize: true); image.InferenceType(); - Assert.Equal(new[] { 1, 3, 112, 112 }, image.Evaluate().AsTensor().Dimensions.ToArray()); + Assert.Equal([1, 3, 112, 112], image.Evaluate().AsTensor().Dimensions.ToArray()); } [Fact] @@ -90,7 +90,7 @@ public void TestOnnxResizeImage() var input = OrtKI.Random(1, 3, 224, 224).ToTensor(); var image = Imaging.ResizeImage(ImageResizeMode.Bilinear, input, Array.Empty(), new[] { 1, 3, 112, 112 }, isTFResize: false); image.InferenceType(); - Assert.Equal(new[] { 1, 3, 112, 112 }, image.Evaluate().AsTensor().Dimensions.ToArray()); + Assert.Equal([1, 3, 112, 112], image.Evaluate().AsTensor().Dimensions.ToArray()); } [Fact] diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorMath.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorMath.cs index b8439d836..769d3c69d 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorMath.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorMath.cs @@ -27,11 +27,11 @@ namespace Nncase.Tests.EvaluatorTest; public class UnitTestEvaluatorMath : TestClassBase { - public static readonly TheoryData ClampInvalidTypeData = new() + public static readonly TheoryData ClampInvalidTypeData = new() { - { new[] { 1, 2, 3, 4 }, new[] { 8 }, new[] { 8 } }, - { new[] { 1, 2, 3, 4 }, new[] { 4 }, new[] { 8 } }, - { new[] { 1, 2, 3, 4 }, new[] { 4 }, new[] { 1 } }, + { [1, 2, 3, 4], [8], [8] }, + { [1, 2, 3, 4], [4], [8] }, + { [1, 2, 3, 4], [4], [1] }, }; [Fact] @@ -117,7 +117,7 @@ public void TestBinaryScalarTensor() var b = new bool[] { true, false, false, true, true, false, false, true }; if (op == BinaryOp.LogicalAnd || op == BinaryOp.LogicalOr || op == BinaryOp.LogicalXor) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, [2, 4])); } } @@ -130,7 +130,7 @@ public void TestBinaryScalarTensor() // if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor) if (op == BinaryOp.LeftShift || op == BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, [2, 4])); } } @@ -141,7 +141,7 @@ public void TestBinaryScalarTensor() var b = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, [2, 4])); } } @@ -152,7 +152,7 @@ public void TestBinaryScalarTensor() var b = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] 
{ 2, 4 }), a, Tensor.From(b, [2, 4])); } } @@ -163,7 +163,7 @@ public void TestBinaryScalarTensor() var b = new long[] { 1, 2, 3, 4, 5, 6, 7, 8 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.FromScalar(a), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), a, Tensor.From(b, [2, 4])); } } } @@ -184,7 +184,7 @@ public void TestBinaryTensorScalar() var b = true; if (op == BinaryOp.LogicalAnd || op == BinaryOp.LogicalOr || op == BinaryOp.LogicalXor) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, new[] { 2, 4 }), b); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, [2, 4]), b); } } @@ -197,7 +197,7 @@ public void TestBinaryTensorScalar() // if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor) if (op == BinaryOp.LeftShift || op == BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, new[] { 2, 4 }), b); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, [2, 4]), b); } } @@ -208,7 +208,7 @@ public void TestBinaryTensorScalar() var b = 2f; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, new[] { 2, 4 }), b); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, [2, 4]), b); } } @@ -219,7 +219,7 @@ public void TestBinaryTensorScalar() var b = 2; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, new[] { 2, 4 }), b); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, [2, 4]), b); } } @@ -230,7 +230,7 @@ public void TestBinaryTensorScalar() var b = 2L; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, new[] { 2, 4 }), b); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.FromScalar(b), Tensor.From(a, [2, 4]), b); } } } @@ -251,7 +251,7 @@ public void TestBinaryTensorTensor() var b = new bool[] { true, false, true, false, true, false, false, true }; if (op == BinaryOp.LogicalAnd || op == BinaryOp.LogicalOr || op == BinaryOp.LogicalXor) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, new[] { 2, 4 }), Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, 
OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, [2, 4]), Tensor.From(b, [2, 4])); } } @@ -264,7 +264,7 @@ public void TestBinaryTensorTensor() // if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor) if (op == BinaryOp.LeftShift || op == BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, new[] { 2, 4 }), Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, [2, 4]), Tensor.From(b, [2, 4])); } } @@ -275,7 +275,7 @@ public void TestBinaryTensorTensor() var b = new float[] { 1, 1, 2, 2, 3, 3, 4, 4 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, new[] { 2, 4 }), Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, [2, 4]), Tensor.From(b, [2, 4])); } } @@ -286,7 +286,7 @@ public void TestBinaryTensorTensor() var b = new int[] { 1, 1, 2, 2, 3, 3, 4, 4 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, new[] { 2, 4 }), Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, [2, 4]), Tensor.From(b, [2, 4])); } } @@ -297,7 +297,7 @@ public void TestBinaryTensorTensor() var b = new long[] { 1, 1, 2, 2, 3, 3, 4, 4 }; if (op != BinaryOp.LogicalAnd && op != BinaryOp.LogicalOr && op != BinaryOp.LogicalXor && op != BinaryOp.LeftShift && op != BinaryOp.RightShift) { - TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, new[] { 2, 4 }), Tensor.From(b, new[] { 2, 4 })); + TestBinaryRunNormal(op, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }), OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }), Tensor.From(a, [2, 4]), Tensor.From(b, [2, 4])); } } } @@ -322,11 +322,11 @@ public void TestClamp() var input = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var min = 3f; var max = 6f; - var expr = IR.F.Math.Clamp(Tensor.From(input, new[] { 2, 4 }), min, max); + var expr = IR.F.Math.Clamp(Tensor.From(input, [2, 4]), min, max); CompilerServices.InferenceType(expr); var result = new float[] { 3, 3, 3, 4, 5, 6, 6, 6 }; - var expect = Tensor.From(result, new[] { 2, 4 }); + var expect = Tensor.From(result, [2, 4]); Assert.Equal(expect, expr.Evaluate().AsTensor()); } @@ -336,14 +336,14 @@ public void TestClampInvalidType() var input = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var min = 3U; var max = 6L; - var expr = IR.F.Math.Clamp(Tensor.From(input, new[] { 2, 4 }), min, max); + var expr = IR.F.Math.Clamp(Tensor.From(input, [2, 4]), min, max); CompilerServices.InferenceType(expr); Assert.IsType(expr.CheckedType); } [Theory] 
[MemberData(nameof(ClampInvalidTypeData))] - public void TestClampInvalidType2(int[] inputShape, int[] minShape, int[] maxShape) + public void TestClampInvalidType2(long[] inputShape, long[] minShape, long[] maxShape) { var input = Tensor.FromScalar(3.3f, inputShape); var min = Tensor.FromScalar(0.0f, minShape); @@ -394,9 +394,9 @@ public void TestCompare() var ort_a = OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }); var ort_b = OrtKISharp.Tensor.MakeTensor(b, new long[] { 2, 4 }); - var expr_a = Tensor.From(a, new[] { 2, 4 }); - var expr_b = Tensor.From(b, new[] { 2, 4 }); - _ = Tensor.From(result, new[] { 2, 4 }).ToOrtTensor(); + var expr_a = Tensor.From(a, [2, 4]); + var expr_b = Tensor.From(b, [2, 4]); + _ = Tensor.From(result, [2, 4]).ToOrtTensor(); var ops = new CompareOp[] { @@ -445,7 +445,7 @@ public void TestCumsum() var input1 = OrtKISharp.Tensor.MakeTensor(input, new long[] { 2, 4 }); var expect = OrtKI.CumSum(input1, axis, exclusive ? 1L : 0L, reverse ? 1L : 0L); - var input2 = Tensor.From(input, new[] { 2, 4 }); + var input2 = Tensor.From(input, [2, 4]); var expr = IR.F.Tensors.CumSum(input2, axis, exclusive, reverse); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); @@ -463,7 +463,7 @@ public void TestDequantize() var expect = OrtKI.DequantizeLinear(input1, scale, zero_point, axis); var quant_param = new QuantParam(zero_point, scale); - var input2 = Tensor.From(input, new[] { 2, 4 }); + var input2 = Tensor.From(input, [2, 4]); var expr = IR.F.Math.Dequantize(input2, quant_param, DataTypes.Float32); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); @@ -481,7 +481,7 @@ public void TestQuantize() var expect = OrtKI.QuantizeLinear(input1, scale, zero_point, axis); var quantParam = new QuantParam(zero_point, scale); - var input2 = Tensor.From(input, new[] { 2, 4 }); + var input2 = Tensor.From(input, [2, 4]); var expr = IR.F.Math.Quantize(input2, quantParam, DataTypes.UInt8); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); @@ -503,7 +503,7 @@ public void TestInt16Quantize() (int)DataTypes.Int16.ToOrtType()); var quantParam = new QuantParam(zeroPoint, scale); - var input2 = Tensor.From(input, new[] { 2, 4 }); + var input2 = Tensor.From(input, [2, 4]); var expr = IR.F.Math.Quantize(input2, quantParam, DataTypes.Int16); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); @@ -516,9 +516,9 @@ public void TestFakeDequantize() byte zero_point = 127; var scale = 0.01F; - var expect = Tensor.From(input, new[] { 2, 4 }); + var expect = Tensor.From(input, [2, 4]); var expr = IR.F.Math.FakeDequantize( - Tensor.From(input, new[] { 2, 4 }), + Tensor.From(input, [2, 4]), new QuantParam(zero_point, scale), DataTypes.Float32); CompilerServices.InferenceType(expr); @@ -532,9 +532,9 @@ public void TestFakeQuantize() byte zero_point = 127; var scale = 0.05F; - var expect = Tensor.From(input, new[] { 2, 4 }); + var expect = Tensor.From(input, [2, 4]); var expr = IR.F.Math.FakeQuantize( - Tensor.From(input, new[] { 2, 4 }), + Tensor.From(input, [2, 4]), new QuantParam(zero_point, scale), DataTypes.UInt8); CompilerServices.InferenceType(expr); @@ -549,8 +549,8 @@ public void TestMatmul() var m1_ort = OrtKISharp.Tensor.MakeTensor(input, new long[] { 2, 4 }); var m2_ort = OrtKISharp.Tensor.MakeTensor(input, new long[] { 4, 2 }); - var m1 = Tensor.From(input, new[] { 2, 4 }); - var m2 = 
Tensor.From(input, new[] { 4, 2 }); + var m1 = Tensor.From(input, [2, 4]); + var m2 = Tensor.From(input, [4, 2]); var expect = OrtKI.MatMul(m1_ort, m2_ort); var expr = IR.F.Math.MatMul(m1, m2); @@ -565,8 +565,8 @@ public void TestMatmul() var m1_ort = OrtKISharp.Tensor.MakeTensor(input1, new long[] { 2, 2, 4 }); var m2_ort = OrtKISharp.Tensor.MakeTensor(input2, new long[] { 4, 2 }); - var m1 = Tensor.From(input1, new[] { 2, 2, 4 }); - var m2 = Tensor.From(input2, new[] { 4, 2 }); + var m1 = Tensor.From(input1, [2, 2, 4]); + var m2 = Tensor.From(input2, [4, 2]); var expect = OrtKI.MatMul(m1_ort, m2_ort); var expr = IR.F.Math.MatMul(m1, m2); @@ -581,8 +581,8 @@ public void TestMatmulInvalidType() { var input1 = new float[] { 1.0F, 1.2F, 1.4F, 1.5F, 1.6F, 1.8F, 1.9F, 2.0F }; var input2 = new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var m1 = Tensor.From(input1, new[] { 2, 4 }); - var m2 = Tensor.From(input2, new[] { 4, 2 }); + var m1 = Tensor.From(input1, [2, 4]); + var m2 = Tensor.From(input2, [4, 2]); var expr = IR.F.Math.MatMul(m1, m2); CompilerServices.InferenceType(expr); @@ -592,8 +592,8 @@ public void TestMatmulInvalidType() { var input1 = new float[] { 1.0F, 1.2F, 1.4F, 1.5F, 1.6F, 1.8F, 1.9F, 2.0F }; var input2 = new float[] { 1.0F, 1.2F, 1.4F, 1.5F, 1.6F, 1.8F, 1.9F, 2.0F }; - var m1 = Tensor.From(input1, new[] { 2, 4 }); - var m2 = Tensor.From(input2, new[] { 1, 8 }); + var m1 = Tensor.From(input1, [2, 4]); + var m2 = Tensor.From(input2, [1, 8]); var expr = IR.F.Math.MatMul(m1, m2); CompilerServices.InferenceType(expr); @@ -638,8 +638,8 @@ public void TestReduce() var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var result = new float[] { 5, 6, 7, 8 }; var ort_a = OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }); - var expr_a = Tensor.From(a, new[] { 2, 4 }); - _ = Tensor.From(result, new[] { 1, 4 }).ToOrtTensor(); + var expr_a = Tensor.From(a, [2, 4]); + _ = Tensor.From(result, [1, 4]).ToOrtTensor(); var ops = new ReduceOp[] { ReduceOp.Max, ReduceOp.Min, ReduceOp.Mean, ReduceOp.Prod, ReduceOp.Sum }; @@ -655,7 +655,7 @@ public void TestReduce() _ => throw new ArgumentOutOfRangeException(nameof(op)), }; - var expr = IR.F.Tensors.Reduce(op, expr_a, Tensor.From(axes, new[] { 1 }), initValue, keepDims); + var expr = IR.F.Tensors.Reduce(op, expr_a, Tensor.From(axes, [1]), initValue, keepDims); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); } @@ -670,8 +670,8 @@ public void TestReduceArg() var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var result = new int[] { 5, 6, 7, 8 }; var ort_a = OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 4 }); - var expr_a = Tensor.From(a, new[] { 2, 4 }); - _ = Tensor.From(result, new[] { 1, 4 }).ToOrtTensor(); + var expr_a = Tensor.From(a, [2, 4]); + _ = Tensor.From(result, [1, 4]).ToOrtTensor(); var ops = new ReduceArgOp[] { ReduceArgOp.ArgMax, ReduceArgOp.ArgMin }; @@ -759,13 +759,13 @@ public void TestUnary() var f = new float[] { 1F, 1.1F, 1.2F, 1.3F }; foreach (var op in ops) { - TestUnaryNormal(op, OrtKISharp.Tensor.MakeTensor(f, new long[] { 2, 2 }), Tensor.From(f, new[] { 2, 2 })); + TestUnaryNormal(op, OrtKISharp.Tensor.MakeTensor(f, new long[] { 2, 2 }), Tensor.From(f, [2, 2])); } } { bool[] a = new bool[] { true, false, false, true }; - TestUnaryNormal(UnaryOp.LogicalNot, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 2 }), Tensor.From(a, new[] { 2, 2 })); + TestUnaryNormal(UnaryOp.LogicalNot, OrtKISharp.Tensor.MakeTensor(a, new long[] { 2, 2 }), Tensor.From(a, [2, 2])); var expr = 
IR.F.Math.Unary(UnaryOp.BitwiseNot, a); CompilerServices.InferenceType(expr); diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs index 802fc6bfe..bac6a125d 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs @@ -167,15 +167,15 @@ public void TestActivationGelu() public void TestBatchToSpace() { var a = new float[] { 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16 }; - var input = Tensor.From(a, new[] { 4, 1, 2, 2 }); + var input = Tensor.From(a, [4, 1, 2, 2]); var shape = new long[] { 2, 2 }; var b = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; - var expect = Tensor.From(b, new[] { 1, 1, 4, 4 }); + var expect = Tensor.From(b, [1, 1, 4, 4]); var crops = new long[] { 0, 0, 0, 0 }; var expr = IR.F.NN.BatchToSpace( input, - Tensor.From(shape, new[] { 2 }), - Tensor.From(crops, new[] { 2, 2 })); + Tensor.From(shape, [2]), + Tensor.From(crops, [2, 2])); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor()); } @@ -202,7 +202,7 @@ public void TestConv2D() weight.ToTensor(), bias.ToTensor(), stride: new[] { 1, 1 }, - padding: Tensor.From(new int[] { 1, 1, 1, 1 }, new[] { 2, 2 }), + padding: Tensor.From(new int[] { 1, 1, 1, 1 }, [2, 2]), dilation: new[] { 1, 1 }, PadMode.Constant, 1); @@ -235,7 +235,7 @@ public void TestConv2D_1() weight.ToTensor(), bias.ToTensor(), stride: new[] { 1, 1 }, - padding: Tensor.From(new int[] { 1, 1, 1, 1 }, new[] { 2, 2 }), + padding: Tensor.From(new int[] { 1, 1, 1, 1 }, [2, 2]), dilation: new[] { 1, 1 }, Nncase.PadMode.Constant, 1, @@ -270,8 +270,8 @@ public void TestConv2DTranspose() bias.ToTensor(), outShape, stride: new[] { 1, 1 }, - padding: Tensor.From(new long[] { 1, 1, 1, 1 }, new[] { 4 }), - outputPadding: Tensor.From(new long[] { 0, 0 }, new[] { 2 }), + padding: Tensor.From(new long[] { 1, 1, 1, 1 }, [4]), + outputPadding: Tensor.From(new long[] { 0, 0 }, [2]), dilation: new[] { 1, 1 }, PadMode.Constant, 1); @@ -347,14 +347,14 @@ public void TestL2Normalization() var a = new float[] { 0F, 2F, 3F, 2F, 2F, 2F }; var b = new float[] { 0F, 0.4F, 0.6F, 0.4F, 0.4F, 0.4F }; { - var expect = Tensor.From(b, new[] { 6 }); - var input = Tensor.From(a, new[] { 6 }); + var expect = Tensor.From(b, [6]); + var input = Tensor.From(a, [6]); DoL2Normalization(expect, input); } { - var expect = Tensor.From(b, new[] { 1, 2, 3 }); - var input = Tensor.From(a, new[] { 1, 2, 3 }); + var expect = Tensor.From(b, [1, 2, 3]); + var input = Tensor.From(a, [1, 2, 3]); DoL2Normalization(expect, input); } } @@ -422,7 +422,7 @@ public void TestLRN() public void TestOneHotTF() { var a = new int[] { 1, 2, 0, 3 }; - var indices = Tensor.From(a, new[] { 4 }); + var indices = Tensor.From(a, [4]); var depth = 5; var values = Tensor.From(new float[] { 0, 1 }, new Shape(new[] { 2 })); var axis = 0L; @@ -439,7 +439,7 @@ public void TestOneHotTF() public void TestOneHotOnnx() { var a = new float[] { 1, 2, 0, 3 }; - var indices = Tensor.From(a, new[] { 4 }); + var indices = Tensor.From(a, [4]); var depth = 5F; var values = Tensor.From(new float[] { 0, 1 }, new Shape(new[] { 2 })); var axis = 1L; @@ -654,16 +654,16 @@ public void TestSoftsign() public void TestSpaceToBatch() { var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; - var input = Tensor.From(a, new[] { 1, 4, 4, 1 }); + var input = Tensor.From(a, [1, 4, 4, 1]); var shape = new long[] { 2, 2 }; var output = new float[] { 
1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16 }; - var expect = Tensor.From(output, new[] { 4, 2, 2, 1 }); + var expect = Tensor.From(output, [4, 2, 2, 1]); var crops = new long[] { 0, 0, 0, 0 }; var expr = NCHWToNHWC(IR.F.NN.SpaceToBatch( NHWCToNCHW(input).Evaluate().AsTensor(), - Tensor.From(shape, new[] { 2 }), - Tensor.From(crops, new[] { 2, 2 }))); + Tensor.From(shape, [2]), + Tensor.From(crops, [2, 2]))); CompilerServices.InferenceType(expr); Assert.Equal(expect, expr.Evaluate().AsTensor()); } diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs index 889f56e48..f4a69a9e5 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorTensors.cs @@ -523,7 +523,7 @@ public void TestStack3() var expr = Tensors.Stack(inputList, 0); CompilerServices.InferenceType(expr); var ret = expr.Evaluate().AsTensor(); - Assert.Equal(new[] { 1, 2, 2 }, ret.Shape.ToValueArray()); + Assert.Equal([1, 2, 2], ret.Shape.ToValueArray()); } { @@ -531,7 +531,7 @@ public void TestStack3() var expr = Tensors.Stack(inputList, 1); CompilerServices.InferenceType(expr); var ret = expr.Evaluate().AsTensor(); - Assert.Equal(new[] { 2, 1, 2 }, ret.Shape.ToValueArray()); + Assert.Equal([2, 1, 2], ret.Shape.ToValueArray()); } { @@ -539,7 +539,7 @@ public void TestStack3() var expr = Tensors.Stack(inputList, 2); CompilerServices.InferenceType(expr); var ret = expr.Evaluate().AsTensor(); - Assert.Equal(new[] { 2, 2, 1 }, ret.Shape.ToValueArray()); + Assert.Equal([2, 2, 1], ret.Shape.ToValueArray()); } } @@ -596,7 +596,7 @@ public void TestTile2() [Fact] public void TestGatherND() { - var shape = new[] { 2, 2 }; + var shape = new long[] { 2, 2 }; var input = new Tensor(new[] { 0, 1, 2, 3 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 1L }, shape); long batchDims = 0L; @@ -610,10 +610,10 @@ public void TestGatherND() [Fact] public void TestScatterND() { - var shape = new[] { 2, 1, 10 }; + var shape = new long[] { 2, 1, 10 }; var input = Tensor.FromScalar(0f, shape); - var indices = new Tensor(new[] { 0L, 0L, 1L, 1L, 0L, 1L }, new[] { 2, 1, 1, 3 }); - var updates = new Tensor(new[] { 5f, 10f }, new[] { 2, 1, 1 }); + var indices = new Tensor(new[] { 0L, 0L, 1L, 1L, 0L, 1L }, [2, 1, 1, 3]); + var updates = new Tensor(new[] { 5f, 10f }, [2, 1, 1]); // var expect = OrtKI.ScatterND(input.ToOrtTensor(), indices.ToOrtTensor(), updates.ToOrtTensor(), "none"); var expect = Tensor.FromScalar(0f, shape); @@ -628,7 +628,7 @@ public void TestScatterND() [Fact] public void TestGather() { - var shape = new[] { 2, 2 }; + var shape = new long[] { 2, 2 }; var input = new Tensor(new[] { 0, 1, 2, 3 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 1L }, shape); long batchDims = 0L; @@ -642,7 +642,7 @@ public void TestGather() [Fact] public void TestGatherElements() { - var shape = new[] { 2, 2 }; + var shape = new long[] { 2, 2 }; var input = new Tensor(new[] { 1, 2, 3, 4 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 0L }, shape); long axis = 1L; @@ -688,7 +688,7 @@ public void TestTopK() public void TestWhere() { var shape = new long[] { 2, 2 }; - var con = new Tensor(new[] { true, false, true, true }, new[] { 2, 2 }); + var con = new Tensor(new[] { true, false, true, true }, [2, 2]); var x = OrtKI.Random(shape); var y = OrtKI.Random(shape); var expect = OrtKI.Where(con.ToOrtTensor(), x, y); @@ -715,7 +715,7 @@ public void TestReduceMean() long axis = 0L; long keepDims = 0L; var a = new 
float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var expr_a = Tensor.From(a, new[] { 2, 4 }); + var expr_a = Tensor.From(a, [2, 4]); var expr = IR.F.Tensors.ReduceMean(expr_a, axis, 0f, keepDims); CompilerServices.InferenceType(expr); var expect = Reduce(ReduceOp.Mean, expr_a, axis, 0f, keepDims); @@ -729,7 +729,7 @@ public void TestReduceMin() long axis = 0L; long keepDims = 0L; var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var expr_a = Tensor.From(a, new[] { 2, 4 }); + var expr_a = Tensor.From(a, [2, 4]); var expr = IR.F.Tensors.ReduceMin(expr_a, axis, 0f, keepDims); CompilerServices.InferenceType(expr); var expect = Reduce(ReduceOp.Min, expr_a, axis, 0f, keepDims); @@ -743,7 +743,7 @@ public void TestReduceMax() long axis = 0L; long keepDims = 0L; var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var expr_a = Tensor.From(a, new[] { 2, 4 }); + var expr_a = Tensor.From(a, [2, 4]); var expr = IR.F.Tensors.ReduceMax(expr_a, axis, 0f, keepDims); CompilerServices.InferenceType(expr); var expect = Reduce(ReduceOp.Max, expr_a, axis, 0f, keepDims); @@ -757,7 +757,7 @@ public void TestReduceSum() long axis = 0L; long keepDims = 0L; var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; - var expr_a = Tensor.From(a, new[] { 2, 4 }); + var expr_a = Tensor.From(a, [2, 4]); var expr = IR.F.Tensors.ReduceSum(expr_a, axis, 0f, keepDims); CompilerServices.InferenceType(expr); var expect = Reduce(ReduceOp.Sum, expr_a, axis, 0f, keepDims); diff --git a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs index 9bbbf554f..84c89b74f 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs @@ -48,7 +48,7 @@ public void TestConstant2() [Fact] public void TestWithVar() { - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 6 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 6 })); var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); var newShape = new Expr[] { 1, 3, dimVar, 6 }; var varMap = new Dictionary { { input, newShape } }; @@ -182,12 +182,12 @@ public void UnitTestReshape() public void UnitTestGetItem() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown })); + var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown() })); var expr = input[1]; var dict = new Dictionary { { input, new[] { dimVar } } }; var shape = expr.EvaluateShapeExpr(dict); var varValues = new Dictionary { { input, Value.FromTensor(new[] { 4 }) } }; - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); var evalShape = expr .Evaluate(new Dictionary { { input, Value.FromTensor(new[] { 2, 3, 4, 5 }) } }) .AsTensor() @@ -200,12 +200,12 @@ public void UnitTestGetItem() public void UnitTestGetItemSingle() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown })); + var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown() })); var expr = input[0]; var dict = new Dictionary { { input, new[] { dimVar } } }; var shape = expr.EvaluateShapeExpr(dict); var varValues = new Dictionary { { input, Value.FromTensor(new[] { 1 }) } }; - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = 
shape.Evaluate(varValues).AsTensor().ToArray(); var evalShape = expr .Evaluate(new Dictionary { { input, Value.FromTensor(new[] { 2 }) } }) .AsTensor() @@ -260,14 +260,14 @@ public void UnitTestPad() public void TestSpaceTobatch() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 192 })); - var paddings = Tensor.From(new[] { 0, 1 }, new[] { 1, 2 }); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 192 })); + var paddings = Tensor.From(new[] { 0, 1 }, [1, 2]); var expr = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), new[] { 3 }, paddings)); var dict = new Dictionary { { input, new Expr[] { 1, dimVar, 192 } } }; var shape = expr.EvaluateShapeExpr(dict); var varValues = new Dictionary { { dimVar, Value.FromTensor(8) } }; Dumpper.DumpIR(shape, "Shape"); - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); var evalShape = expr .Evaluate(new Dictionary { { input, Value.FromTensor(Testing.Rand(1, 8, 192)) } }) .AsTensor() @@ -280,14 +280,14 @@ public void TestSpaceTobatch() public void TestBatchToSpace() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, 69, 192 })); - var paddings = Tensor.From(new[] { 0, 1 }, new[] { 1, 2 }); + var input = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown(), 69, 192 })); + var paddings = Tensor.From(new[] { 0, 1 }, [1, 2]); var expr = BatchToSpace(input, new[] { 3 }, paddings); var dict = new Dictionary { { input, new Expr[] { dimVar, 69, 192 } } }; var shape = expr.EvaluateShapeExpr(dict); var varValues = new Dictionary { { dimVar, Value.FromTensor(3) } }; Dumpper.DumpIR(shape, "Shape"); - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); var evalShape = expr .Evaluate(new Dictionary { { input, Value.FromTensor(Testing.Rand(3, 69, 192)) } }) .AsTensor() @@ -319,7 +319,7 @@ public void UnitTestRange(int beginV, int endV, int stepV) { step, Value.FromTensor(stepV) }, }; - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); var fixedShape = expr.Evaluate(varValues).AsTensor().Shape.ToValueArray(); Assert.Equal(fixedShape, shapeValue); } @@ -339,7 +339,7 @@ private void TestOpShapeEval(Func exprCtor, Var input, Expr[] newSha var expr = exprCtor(input); var shape = expr.EvaluateShapeExpr(varMap); var varValues = newShape.Where(x => x is Var).ToDictionary(x => (Var)x, _ => (IValue)Value.FromTensor(_defaultDim)); - var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); + var shapeValue = shape.Evaluate(varValues).AsTensor().ToArray(); var fixedShape = newShape.Select(x => { @@ -357,7 +357,7 @@ private void TestOpShapeEval(Func exprCtor, Var input, Expr[] newSha private void TestOpShapeEval(Func exprCtor) { - var (input, newShape) = MakeInput(new[] { 1, 3, Dimension.Unknown, 24 }); + var (input, newShape) = MakeInput(new[] { 1, 3, Dimension.Unknown(), 24 }); TestOpShapeEval(exprCtor, input, newShape); } } diff --git a/src/Nncase.Tests/Quant/UnitTestAddRangeOfMarker.cs b/src/Nncase.Tests/Quant/UnitTestAddRangeOfMarker.cs index 204fb0392..f55083bbb 100644 --- a/src/Nncase.Tests/Quant/UnitTestAddRangeOfMarker.cs +++ 
b/src/Nncase.Tests/Quant/UnitTestAddRangeOfMarker.cs @@ -103,7 +103,7 @@ public SolidCalibrationDatasetProvider(IEnumerable vars) CompilerServices.InferenceType(var); var shape = var.CheckedShape.Select(d => d.IsUnknown ? 1 : d.FixedValue).ToArray(); - var shapeSize = 1; + long shapeSize = 1; for (int j = 0; j < shape.Length; j++) { shapeSize *= shape[j]; diff --git a/src/Nncase.Tests/Quant/UnitTestDumpQuantError.cs b/src/Nncase.Tests/Quant/UnitTestDumpQuantError.cs index 985ad2eae..94d6ab122 100644 --- a/src/Nncase.Tests/Quant/UnitTestDumpQuantError.cs +++ b/src/Nncase.Tests/Quant/UnitTestDumpQuantError.cs @@ -34,12 +34,12 @@ public async Task TestDumpQuantError() weightsValue.Add(i * 1.0f / (32 * 3 * 3 * 3)); } - Expr weights = Tensor.From(weightsValue.ToArray(), new[] { 32, 3, 3, 3 }); + Expr weights = Tensor.From(weightsValue.ToArray(), [32, 3, 3, 3]); weights.Metadata.OutputNames = new string[] { "weight" }; var bias = Normal(DataTypes.Float32, new[] { 32 }).Evaluate().AsTensor(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 0 }, { 0, 0 } }; var conv = Conv2D(input, weights, bias, stride, padding, dilation, PadMode.Constant, 1); @@ -62,12 +62,12 @@ public async Task TestDumpQuantErrorFromConfig() weightsValue.Add(i * 1.0f / (32 * 3 * 3 * 3)); } - Expr weights = Tensor.From(weightsValue.ToArray(), new[] { 32, 3, 3, 3 }); + Expr weights = Tensor.From(weightsValue.ToArray(), [32, 3, 3, 3]); weights.Metadata.OutputNames = new string[] { "weight" }; var bias = Normal(DataTypes.Float32, new[] { 32 }).Evaluate().AsTensor(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 0 }, { 0, 0 } }; var conv = Conv2D(input, weights, bias, stride, padding, dilation, PadMode.Constant, 1); @@ -128,7 +128,7 @@ public SolidCalibrationDatasetProvider(IEnumerable vars) CompilerServices.InferenceType(var); var shape = var.CheckedShape.Select(d => d.IsUnknown ? 
1 : d.FixedValue).ToArray(); - var shapeSize = 1; + long shapeSize = 1; for (int j = 0; j < shape.Length; j++) { shapeSize *= shape[j]; diff --git a/src/Nncase.Tests/Quant/UnitTestExportQuantScheme.cs b/src/Nncase.Tests/Quant/UnitTestExportQuantScheme.cs index 8f2dd6268..8e124cb61 100644 --- a/src/Nncase.Tests/Quant/UnitTestExportQuantScheme.cs +++ b/src/Nncase.Tests/Quant/UnitTestExportQuantScheme.cs @@ -32,12 +32,12 @@ public async Task TestExportQuantSchemeForWeightsByTensorConv2D() weightsValue.Add(i * 1.0f / (32 * 3 * 3 * 3)); } - Expr weights = Tensor.From(weightsValue.ToArray(), new[] { 32, 3, 3, 3 }); + Expr weights = Tensor.From(weightsValue.ToArray(), [32, 3, 3, 3]); weights.Metadata.OutputNames = new string[] { "weight" }; var bias = Normal(DataTypes.Float32, new[] { 32 }).Evaluate().AsTensor(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 0 }, { 0, 0 } }; var conv = Conv2D(input, weights, bias, stride, padding, dilation, PadMode.Constant, 1); @@ -62,12 +62,12 @@ public async Task TestExportQuantSchemeForWeightsByChannelConv2D() weightsValue.Add(i * 1.0f / (3 * 3 * 3 * 3)); } - Expr weights = Tensor.From(weightsValue.ToArray(), new[] { 3, 3, 3, 3 }); + Expr weights = Tensor.From(weightsValue.ToArray(), [3, 3, 3, 3]); weights.Metadata.OutputNames = new string[] { "weight" }; var bias = Normal(DataTypes.Float32, new[] { 3 }).Evaluate().AsTensor(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 0 }, { 0, 0 } }; var conv = Conv2D(input, weights, bias, stride, padding, dilation, PadMode.Constant, 1); @@ -131,7 +131,7 @@ public SolidCalibrationDatasetProvider(IEnumerable vars) CompilerServices.InferenceType(var); var shape = var.CheckedShape.Select(d => d.IsUnknown ? 1 : d.FixedValue).ToArray(); - var shapeSize = 1; + long shapeSize = 1; for (int j = 0; j < shape.Length; j++) { shapeSize *= shape[j]; diff --git a/src/Nncase.Tests/Quant/UnitTestImportQuantScheme.cs b/src/Nncase.Tests/Quant/UnitTestImportQuantScheme.cs index bff7d7009..ed3d1f0fa 100644 --- a/src/Nncase.Tests/Quant/UnitTestImportQuantScheme.cs +++ b/src/Nncase.Tests/Quant/UnitTestImportQuantScheme.cs @@ -46,12 +46,12 @@ public async Task TestImportQuantSchemeForConv2D() weightsValue.Add(i * 1.0f / (32 * 3 * 3 * 3)); } - Expr weights = Tensor.From(weightsValue.ToArray(), new[] { 32, 3, 3, 3 }); + Expr weights = Tensor.From(weightsValue.ToArray(), [32, 3, 3, 3]); weights.Metadata.OutputNames = new string[] { "weight" }; var bias = Normal(DataTypes.Float32, new[] { 32 }).Evaluate().AsTensor(); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[,] { { 0, 0 }, { 0, 0 } }; var conv = Conv2D(input, weights, bias, stride, padding, dilation, PadMode.Constant, 1); @@ -120,7 +120,7 @@ public SolidCalibrationDatasetProvider(IEnumerable vars) CompilerServices.InferenceType(var); var shape = var.CheckedShape.Select(d => d.IsUnknown ? 
1 : d.FixedValue).ToArray(); - var shapeSize = 1; + long shapeSize = 1; for (int j = 0; j < shape.Length; j++) { shapeSize *= shape[j]; diff --git a/src/Nncase.Tests/Quant/UnitTestQuantAlgorithm.cs b/src/Nncase.Tests/Quant/UnitTestQuantAlgorithm.cs index ab321b0ea..9aa2db317 100644 --- a/src/Nncase.Tests/Quant/UnitTestQuantAlgorithm.cs +++ b/src/Nncase.Tests/Quant/UnitTestQuantAlgorithm.cs @@ -85,10 +85,10 @@ public async Task TestKLQuant() biasValue.Add(((i * 1.0f / 16) - 0.5f) * 2); } - var weights = Tensor.From(weightsValue.ToArray(), new[] { 16, 3, 3, 3 }); - var bias = Tensor.From(biasValue.ToArray(), new[] { 16 }); - var stride = Tensor.From(new[] { 1, 1 }, new[] { 2 }); - var dilation = Tensor.From(new[] { 1, 1 }, new[] { 2 }); + var weights = Tensor.From(weightsValue.ToArray(), [16, 3, 3, 3]); + var bias = Tensor.From(biasValue.ToArray(), [16]); + var stride = Tensor.From(new[] { 1, 1 }, [2]); + var dilation = Tensor.From(new[] { 1, 1 }, [2]); var padding = new[] { new[] { 0, 1 }, new[] { 0, 0 } }; var conv = IR.F.NN.Conv2D(input, weights, bias, stride, Pad(padding), dilation, PadMode.Constant, 1); @@ -201,7 +201,7 @@ public void TestSQuant5() QuantAlgorithmUtility.SquantWeights(weights, range, inputWeightsShape, quantMode, bits, false)); } - private Expr Pad(int[][] p) => Const.FromTensor(Tensor.From(p.SelectMany(i => i).ToArray(), new[] { 2, 2 })); + private Expr Pad(int[][] p) => Const.FromTensor(Tensor.From(p.SelectMany(i => i).ToArray(), [2, 2])); public sealed class DumpVisitor : ExprVisitor { @@ -255,7 +255,7 @@ public SolidCalibrationDatasetProvider(IEnumerable vars) CompilerServices.InferenceType(var); var shape = var.CheckedShape.Select(d => d.IsUnknown ? 1 : d.FixedValue).ToArray(); - var shapeSize = 1; + long shapeSize = 1; for (int j = 0; j < shape.Length; j++) { shapeSize *= shape[j]; diff --git a/src/Nncase.Tests/Rewrite/RewriteBase.cs b/src/Nncase.Tests/Rewrite/RewriteBase.cs index e03b24e32..b4344ced1 100644 --- a/src/Nncase.Tests/Rewrite/RewriteBase.cs +++ b/src/Nncase.Tests/Rewrite/RewriteBase.cs @@ -60,7 +60,7 @@ public static Call Conv2D(Expr input, int in_channels, int out_channels, int ker { var weights = OrtKI.Random(out_channels, in_channels, kernel, kernel).ToTensor(); var bias = OrtKI.Random(out_channels).ToTensor(); - return IR.F.NN.Conv2D(input, weights, bias, new[] { stride, stride }, Tensor.From(new[] { 1, 1, 1, 1 }, new[] { 2, 2 }), new[] { 1, 1 }, PadMode.Constant, 1); + return IR.F.NN.Conv2D(input, weights, bias, new[] { stride, stride }, Tensor.From(new[] { 1, 1, 1, 1 }, [2, 2]), new[] { 1, 1 }, PadMode.Constant, 1); } } @@ -678,7 +678,7 @@ public override Function PreExpr get { var v5 = Transpose(Input, new[] { 0, 2, 3, 1 }); // f32[1,15,20,16] - var v6 = PRelu(v5, Tensor.From(Enumerable.Repeat(0.2f, 16).ToArray(), new[] { 1, 1, 16 })); // f32[1,15,20,16] + var v6 = PRelu(v5, Tensor.From(Enumerable.Repeat(0.2f, 16).ToArray(), [1, 1, 16])); // f32[1,15,20,16] var v7 = Transpose(v6, new[] { 0, 3, 1, 2 }); // f32[1,16,15,20] var v8 = Conv2D( v7, @@ -706,7 +706,7 @@ public override Function PreExpr get { var v5 = Transpose(Input, new[] { 0, 2, 3, 1 }); // f32[1,15,20,16] - var v6 = PRelu(v5, Tensor.From(Enumerable.Repeat(0.2f, 16).ToArray(), new[] { 1, 1, 1, 16 })); // f32[1,15,20,16] + var v6 = PRelu(v5, Tensor.From(Enumerable.Repeat(0.2f, 16).ToArray(), [1, 1, 1, 16])); // f32[1,15,20,16] var v7 = Transpose(v6, new[] { 0, 3, 1, 2 }); // f32[1,16,15,20] var v8 = Conv2D( v7, @@ -1671,7 +1671,7 @@ public Function PreExpr var exclusive = false; var 
reverse = false; - var input1 = Tensor.From(input, new[] { 2, 4 }); + var input1 = Tensor.From(input, [2, 4]); var expr = IR.F.Tensors.CumSum(input1, axis, exclusive, reverse); return new Function(expr, new Var[] { _input }); } @@ -1888,8 +1888,8 @@ public Function PreExpr bias.ToTensor(), outShape, stride: new[] { 1, 1 }, - padding: Tensor.From(new long[] { 1, 1, 1, 1 }, new[] { 4 }), - outputPadding: Tensor.From(new long[] { 0, 0 }, new[] { 2 }), + padding: Tensor.From(new long[] { 1, 1, 1, 1 }, [4]), + outputPadding: Tensor.From(new long[] { 0, 0 }, [2]), dilation: new[] { 1, 1 }, PadMode.Constant, 1); @@ -1977,7 +1977,7 @@ public Function PreExpr byte zero_point = 127; var scale = 0.01F; var expr = IR.F.Math.FakeDequantize( - Tensor.From(input, new[] { 2, 4 }), + Tensor.From(input, [2, 4]), new QuantParam(zero_point, scale), DataTypes.Float32); return new Function(expr, new Var[] { _input }); @@ -2009,7 +2009,7 @@ public Function PreExpr byte zero_point = 127; var scale = 0.05F; var expr = IR.F.Math.FakeQuantize( - Tensor.From(input, new[] { 2, 4 }), + Tensor.From(input, [2, 4]), new QuantParam(zero_point, scale), DataTypes.UInt8); return new Function(expr, new Var[] { _input }); @@ -2068,7 +2068,7 @@ public Function PreExpr { get { - var shape = new[] { 2, 2 }; + var shape = new long[] { 2, 2 }; var input = new Tensor(new[] { 0, 1, 2, 3 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 1L }, shape); long batchDims = 0L; @@ -2098,7 +2098,7 @@ public Function PreExpr { get { - var shape = new[] { 2, 2 }; + var shape = new long[] { 2, 2 }; var input = new Tensor(new[] { 0, 1, 2, 3 }, shape); var indices = new Tensor(new[] { 0L, 0L, 1L, 1L }, shape); long batchDims = 0L; @@ -2358,7 +2358,7 @@ public Function PreExpr get { var shape = new long[] { 2, 2 }; - var con = new Tensor(new[] { true, false, true, true }, new[] { 2, 2 }); + var con = new Tensor(new[] { true, false, true, true }, [2, 2]); var x = OrtKI.Random(shape); var y = OrtKI.Random(shape); var expr = IR.F.Tensors.Where(con, x.ToTensor(), y.ToTensor()); @@ -2445,14 +2445,14 @@ public Function PreExpr get { var a = new float[] { 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16 }; - var input = Tensor.From(a, new[] { 4, 1, 2, 2 }); + var input = Tensor.From(a, [4, 1, 2, 2]); var shape = new long[] { 2, 2 }; _ = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; var crops = new long[] { 0, 0, 0, 0 }; var expr = IR.F.NN.BatchToSpace( input, - Tensor.From(shape, new[] { 2 }), - Tensor.From(crops, new[] { 2, 2 })); + Tensor.From(shape, [2]), + Tensor.From(crops, [2, 2])); return new Function(expr, new Var[] { _input }); } } @@ -2480,7 +2480,7 @@ public Function PreExpr { var a = new float[] { 0F, 2F, 3F, 2F, 2F, 2F }; _ = new float[] { 0F, 0.4F, 0.6F, 0.4F, 0.4F, 0.4F }; - var input = Tensor.From(a, new[] { 6 }); + var input = Tensor.From(a, [6]); var expr = IR.F.NN.L2Normalization(input); return new Function(expr, new Var[] { _input }); } @@ -2508,7 +2508,7 @@ public Function PreExpr get { var a = new int[] { 1, 2, 0, 3 }; - var indices = Tensor.From(a, new[] { 4 }); + var indices = Tensor.From(a, [4]); var depth = 5; var values = Tensor.From(new int[] { 0, 1 }, new Shape(new[] { 2 })); var axis = 0L; @@ -2684,8 +2684,8 @@ public Function PreExpr long select_last_idx = 0L; var a = new float[] { 1, 2, 3, 4, 5, 6, 7, 8 }; var result = new int[] { 5, 6, 7, 8 }; - var expr_a = Tensor.From(a, new[] { 2, 4 }); - _ = Tensor.From(result, new[] { 1, 4 }).ToOrtTensor(); + var expr_a = Tensor.From(a, [2, 4]); + _ 
= Tensor.From(result, [1, 4]).ToOrtTensor(); var expr = IR.F.Tensors.ReduceArg(ReduceArgOp.ArgMax, DataTypes.Int64, expr_a, axis, 0L, select_last_idx); return new Function(expr, new Var[] { _input }); } @@ -2854,7 +2854,7 @@ public PReluTransposeCase() var v0 = Transpose(input, new[] { 0, 3, 1, 2 }); // f32[1,1,33,65] var v1 = IR.F.NN.Conv2D(v0, IR.F.Random.Normal(new[] { 8, 1, 3, 3 }).Evaluate().AsTensor(), new[] { 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } }, new[] { 1, 1 }, PadMode.Constant, 1, new[] { -float.PositiveInfinity, float.PositiveInfinity }); // f32[1,8,33,65] var v2 = Transpose(v1, new[] { 0, 2, 3, 1 }); // f32[1,33,65,8] - var v3 = PRelu(v2, Tensor.From(new[] { -0.12399824f, -0.03634571f, 0.5353417f, -0.67039806f, 0.91027457f, -1.0752988f, 0.55657554f, -1.1045103f }, new[] { 1, 1, 8 })); // f32[1,33,65,8] + var v3 = PRelu(v2, Tensor.From(new[] { -0.12399824f, -0.03634571f, 0.5353417f, -0.67039806f, 0.91027457f, -1.0752988f, 0.55657554f, -1.1045103f }, [1, 1, 8])); // f32[1,33,65,8] PreExpr = new Function(v3, new[] { input }); } diff --git a/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewrite.cs b/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewrite.cs index 1f25ee26d..1803c407f 100644 --- a/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewrite.cs +++ b/src/Nncase.Tests/Rewrite/UnitTestDataFlowRewrite.cs @@ -174,8 +174,8 @@ public async Task TestYolo20MinStructure() var dilationW = 1; var padH = Util.GetWindowedPadding(inH, fH, strideH, dilationH, true); var padW = Util.GetWindowedPadding(inW, fW, strideW, dilationW, true); - var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); - var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); + var stride = Tensor.From(new[] { strideH, strideW }, [2]); + var dilation = Tensor.From(new[] { dilationH, dilationW }, [2]); var padding = Util.ConcatPadding(padH, padW); var conv = NN.Conv2D( @@ -197,7 +197,7 @@ public async Task TestYolo20MinStructure() var max = Binary(BinaryOp.Max, convAfterTranspose, mul); // ReduceWindow2D - var doubleV = Tensor.From(new[] { 2, 2 }, new[] { 2 }); + var doubleV = Tensor.From(new[] { 2, 2 }, [2]); var initValue = (Const)0; var (rInH, rInW) = Util.GetHW(max); var rPadH = Util.GetWindowedPadding(rInH, 2, 2, dilationH, true); @@ -276,12 +276,12 @@ public async Task TestReshapeToByChannel() Assert.Equal(new long[] { 3, 1, 1 }, afterShape); var b = Reshape(v, afterShape); b.InferenceType(); - Assert.Equal(new[] { 3, 1, 1 }, b.Evaluate().AsTensor().Dimensions.ToArray()); + Assert.Equal([3, 1, 1], b.Evaluate().AsTensor().Dimensions.ToArray()); var a = OnnxImporter.ReshapeToByChannel(v); var after = await RunShapeInferPass("ReshapeToByChannel", a); Assert.True(after.InferenceType()); - Assert.Equal(new[] { 3, 1, 1 }, after.Evaluate().AsTensor().Dimensions.ToArray()); + Assert.Equal([3, 1, 1], after.Evaluate().AsTensor().Dimensions.ToArray()); } [Fact] diff --git a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewrite.cs b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewrite.cs index 82050273a..a2dae5e42 100644 --- a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewrite.cs +++ b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewrite.cs @@ -114,8 +114,8 @@ public void TestClassicDemo() [Fact] public void TestTransposeBinaryMotion() { - var c0 = (Call)NHWCToNCHW(Tensor.FromScalar(1, new[] { 2, 2, 3, 4 })); - var c1 = (Call)NHWCToNCHW(Tensor.FromScalar(1, new[] { 2, 2, 1, 1 })); + var c0 = (Call)NHWCToNCHW(Tensor.FromScalar(1, [2, 2, 3, 4])); + var c1 = (Call)NHWCToNCHW(Tensor.FromScalar(1, 
[2, 2, 1, 1])); Assert.Equal(c0.Arguments[1].GetHashCode(), c1.Arguments[1].GetHashCode()); Expr pre = c0 + c1; diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineBinary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineBinary.cs index 264f962f2..6f54b225b 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineBinary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineBinary.cs @@ -40,9 +40,9 @@ public static IEnumerable CombineClampBinaryPositiveData }; var clampShapes = new object[] { - Array.Empty(), - new[] { 32 }, - new[] { 24, 24, 32 }, + Array.Empty(), + new long[] { 32 }, + new long[] { 24, 24, 32 }, }; var mins = new object[] { @@ -58,8 +58,8 @@ public static IEnumerable CombineClampBinaryPositiveData { p[0], p[1], - Tensor.FromScalar((float)p[3], (int[])p[2]), // min - Tensor.FromScalar((float)p[4], (int[])p[2]), // max + Tensor.FromScalar((float)p[3], (long[])p[2]), // min + Tensor.FromScalar((float)p[4], (long[])p[2]), // max }); } } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs index 918b685d2..b9a092719 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineReshape.cs @@ -25,66 +25,66 @@ namespace Nncase.Tests.Rules.NeutralTest; [AutoSetupTestMethod(InitSession = true)] public class UnitTestCombineReshape : TransformTestBase { - public static readonly TheoryData CombineConstBinaryReshapePositiveData = new() - { - // BinaryOp binaryOp, int[] lShape, int[] rShape, int[] shape, bool leftConst - { BinaryOp.Add, new[] { 32, 1, 32, 64 }, new[] { 64 }, new[] { 1, 32, 32, 64 }, false }, - { BinaryOp.Sub, new[] { 1, 32, 32, 64 }, new[] { 1 }, new[] { 1, 1024, 1, 64 }, false }, - { BinaryOp.Div, new[] { 64 }, new[] { 32, 1, 32, 64, }, new[] { 1, 32, 32, 64 }, true }, - { BinaryOp.Mul, new[] { 1 }, new[] { 1, 32, 32, 64, }, new[] { 1, 1024, 64, 1 }, true }, - { BinaryOp.Sub, new[] { 1 }, new[] { 1, 32, 32, 64, }, new[] { 1, 1024, 64, 1 }, true }, + public static readonly TheoryData CombineConstBinaryReshapePositiveData = new() + { + // BinaryOp binaryOp, long[] lShape, long[] rShape, long[] shape, bool leftConst + { BinaryOp.Add, new long[] { 32, 1, 32, 64 }, new long[] { 64 }, new long[] { 1, 32, 32, 64 }, false }, + { BinaryOp.Sub, new long[] { 1, 32, 32, 64 }, new long[] { 1 }, new long[] { 1, 1024, 1, 64 }, false }, + { BinaryOp.Div, new long[] { 64 }, new long[] { 32, 1, 32, 64, }, new long[] { 1, 32, 32, 64 }, true }, + { BinaryOp.Mul, new long[] { 1 }, new long[] { 1, 32, 32, 64, }, new long[] { 1, 1024, 64, 1 }, true }, + { BinaryOp.Sub, new long[] { 1 }, new long[] { 1, 32, 32, 64, }, new long[] { 1, 1024, 64, 1 }, true }, }; - public static readonly TheoryData TestCombineReshapeTransposeNegativeData = + public static readonly TheoryData TestCombineReshapeTransposeNegativeData = new() { - { new[] { 1, 77, 1, 64 }, new[] { 2, 1, 3, 0 }, new[] { 77, 64, 1 } }, - { new[] { 1, 77, 12, 64 }, new[] { 1, 0, 2, 3 }, new[] { 1, 77, 768 } }, + { new long[] { 1, 77, 1, 64 }, new long[] { 2, 1, 3, 0 }, new long[] { 77, 64, 1 } }, + { new long[] { 1, 77, 12, 64 }, new long[] { 1, 0, 2, 3 }, new long[] { 1, 77, 768 } }, }; public static IEnumerable CombineBinaryReshapePositiveData => new[] { - new object[] { new[] { 5, 4 }, new[] { 5, 4 }, new[] { 1, 20 } }, - new object[] { new[] { 4, 4 }, new[] { 4, 4 }, new[] { 2, 8 } }, - new object[] { new[] { 4 }, new[] { 4 }, new[] { 4 } }, - new object[] { new[] { 1, 3, 4 }, new[] { 1, 
3, 4 }, new[] { 1, 4, 3 } }, - new object[] { new[] { 1, 3, 2, 4 }, new[] { 1, 3, 2, 4 }, new[] { 1, 1, 6, 4 } }, + new object[] { new long[] { 5, 4 }, new long[] { 5, 4 }, new long[] { 1, 20 } }, + new object[] { new long[] { 4, 4 }, new long[] { 4, 4 }, new long[] { 2, 8 } }, + new object[] { new long[] { 4 }, new long[] { 4 }, new long[] { 4 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 1, 3, 4 }, new long[] { 1, 4, 3 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 1, 3, 2, 4 }, new long[] { 1, 1, 6, 4 } }, }; public static IEnumerable CombineConstBinaryReshapeNegativeData => new[] { - new object[] { new[] { 1, 32, 32, 64, }, new[] { 32, 64 }, new[] { 1, 16, 64, 64 } }, + new object[] { new long[] { 1, 32, 32, 64, }, new long[] { 32, 64 }, new long[] { 1, 16, 64, 64 } }, }; public static IEnumerable CombineBinaryReshapeNegativeData => new[] { - new object[] { new[] { 5, 4 }, new[] { 4, 5 }, new[] { 1, 20 } }, + new object[] { new long[] { 5, 4 }, new long[] { 4, 5 }, new long[] { 1, 20 } }, }; public static IEnumerable TestCombineUnaryReshapePositiveData => new[] { - new object[] { UnaryOp.Exp, new[] { 1, 3, 4 }, new[] { 1, 4, 3 } }, - new object[] { UnaryOp.Sqrt, new[] { 1, 3, 4 }, new[] { 3, 4, 1 } }, - new object[] { UnaryOp.Log, new[] { 1, 3, 4, 5 }, new[] { 3, 1, 1, 20 } }, - new object[] { UnaryOp.Abs, new[] { 1, 3, 4, 5 }, new[] { 1, 12, 5, 1 } }, + new object[] { UnaryOp.Exp, new long[] { 1, 3, 4 }, new long[] { 1, 4, 3 } }, + new object[] { UnaryOp.Sqrt, new long[] { 1, 3, 4 }, new long[] { 3, 4, 1 } }, + new object[] { UnaryOp.Log, new long[] { 1, 3, 4, 5 }, new long[] { 3, 1, 1, 20 } }, + new object[] { UnaryOp.Abs, new long[] { 1, 3, 4, 5 }, new long[] { 1, 12, 5, 1 } }, }; public static IEnumerable TestCombineReshapePadPositiveData => new[] { - new object[] { new[] { 1, 3, 4 }, new[] { 1, 5, 8 }, new[] { 0, 0, 1, 1, 2, 2 } }, - new object[] { new[] { 1, 3, 4 }, new[] { 1, 4, 5, 7 }, new[] { 1, 2, 1, 1, 2, 1 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 1, 5, 8 }, new long[] { 0, 0, 1, 1, 2, 2 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 1, 4, 5, 7 }, new long[] { 1, 2, 1, 1, 2, 1 } }, }; public static IEnumerable TestCombineReshapePadNegativeData => new[] { - new object[] { new[] { 1, 3, 4 }, new[] { 5, 8 }, new[] { 0, 0, 1, 1, 2, 2 } }, - new object[] { new[] { 1, 3, 4 }, new[] { 1, 4, 1, 35 }, new[] { 1, 2, 1, 1, 2, 1 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 5, 8 }, new long[] { 0, 0, 1, 1, 2, 2 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 1, 4, 1, 35 }, new long[] { 1, 2, 1, 1, 2, 1 } }, }; public static TheoryData<(int Count, IR.Expr Act)> TestCombineActivationsReshapePositiveData => new() @@ -104,7 +104,7 @@ public class UnitTestCombineReshape : TransformTestBase [Theory] [MemberData(nameof(CombineConstBinaryReshapePositiveData))] - public void TestCombineConstBinaryReshapePositive(BinaryOp binaryOp, int[] lShape, int[] rShape, int[] shape, bool leftConst) + public void TestCombineConstBinaryReshapePositive(BinaryOp binaryOp, long[] lShape, long[] rShape, long[] shape, bool leftConst) { Expr lhs = leftConst ? lShape.Sum() == 1 ? 0.5f : Const.FromValue(Random.Normal(DataTypes.Float32, 0, 1, 3, lShape).Evaluate()) : new Var("lhs", new TensorType(DataTypes.Float32, lShape)); Expr rhs = leftConst ? 
new Var("b", new TensorType(DataTypes.Float32, rShape)) : @@ -140,7 +140,7 @@ public void TestCombineBinaryReshapePositive(int[] lShape, int[] rShape, int[] s [Theory] [MemberData(nameof(CombineBinaryReshapeNegativeData))] - public void TestCombineBinaryReshapeNegative(int[] lShape, int[] rShape, int[] shape) + public void TestCombineBinaryReshapeNegative(long[] lShape, long[] rShape, long[] shape) { var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape); var b = Tensor.From(Random.Normal(DataTypes.Float32, 0, 1, 0, rShape).Evaluate().AsTensor().ToArray(), rShape); @@ -151,7 +151,7 @@ public void TestCombineBinaryReshapeNegative(int[] lShape, int[] rShape, int[] s [Theory] [MemberData(nameof(CombineConstBinaryReshapeNegativeData))] - public void TestCombineConstBinaryReshapeNegative(int[] lShape, int[] rShape, int[] shape) + public void TestCombineConstBinaryReshapeNegative(long[] lShape, long[] rShape, long[] shape) { var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape); var b = Tensor.From(Random.Normal(DataTypes.Float32, 0, 1, 0, rShape).Evaluate().AsTensor().ToArray(), rShape); @@ -162,7 +162,7 @@ public void TestCombineConstBinaryReshapeNegative(int[] lShape, int[] rShape, in [Theory] [MemberData(nameof(TestCombineUnaryReshapePositiveData))] - public void TestCombineUnaryReshapePositive(UnaryOp opType, int[] inShape, int[] shape) + public void TestCombineUnaryReshapePositive(UnaryOp opType, long[] inShape, long[] shape) { var a = new Var(); var normal = new Dictionary(); @@ -192,7 +192,7 @@ public void TestCombineReshapePadPositive(int[] inShape, int[] shape, int[] pads var a = new Var("input", new TensorType(DataTypes.Float32, inShape)); var normal = new Dictionary(); normal.Add(a, Random.Normal(DataTypes.Float32, 0, 1, 0, inShape).Evaluate()); - var rootPre = Tensors.Reshape(NN.Pad(a, Tensor.From(pads, new[] { pads.Length / 2, 2 }), PadMode.Constant, 0f), shape); + var rootPre = Tensors.Reshape(NN.Pad(a, Tensor.From(pads, [pads.Length / 2, 2]), PadMode.Constant, 0f), shape); TestMatched(rootPre, normal); } @@ -201,7 +201,7 @@ public void TestCombineReshapePadPositive(int[] inShape, int[] shape, int[] pads public void TestCombineReshapePadNegative(int[] inShape, int[] shape, int[] pads) { var a = new Var("input", new TensorType(DataTypes.Float32, inShape)); - var rootPre = Tensors.Reshape(NN.Pad(a, Tensor.From(pads, new[] { pads.Length / 2, 2 }), PadMode.Constant, 0f), shape); + var rootPre = Tensors.Reshape(NN.Pad(a, Tensor.From(pads, [pads.Length / 2, 2]), PadMode.Constant, 0f), shape); TestNotMatch(rootPre); } @@ -220,7 +220,7 @@ public void TestCombineReshapeTransposePostive(int[] inShape, int[] perm, int[] [Theory] [MemberData(nameof(TestCombineReshapeTransposeNegativeData))] - public void TestCombineReshapeTransposeNegative(int[] inShape, int[] perm, int[] newshape) + public void TestCombineReshapeTransposeNegative(long[] inShape, long[] perm, long[] newshape) { var input = new Var("input", new TensorType(DataTypes.Float32, inShape)); var rootPre = Tensors.Reshape(Tensors.Transpose(input, perm), newshape); diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs index dcde48a8b..30a6fe477 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineTranspose.cs @@ -25,63 +25,63 @@ namespace Nncase.Tests.Rules.NeutralTest; [AutoSetupTestMethod(InitSession = true)] public class UnitTestCombineTranspose : TransformTestBase { - public 
static readonly TheoryData CombineTransposeConstBinaryPositiveData = new() + public static readonly TheoryData CombineTransposeConstBinaryPositiveData = new() { - // BinaryOp binaryOp, int[] lShape, int[] rShape, int[] perm, bool leftConst - { BinaryOp.Add, new[] { 1, 32, 32, 64, }, new[] { 64 }, new[] { 0, 3, 1, 2 }, false }, - { BinaryOp.Add, new[] { 1, 32, 32, 64, }, Array.Empty(), new[] { 0, 3, 1, 2 }, false }, - { BinaryOp.Sub, new[] { 1, 32, 32, 64, }, new[] { 32, 64 }, new[] { 0, 3, 1, 2 }, false }, - { BinaryOp.Mul, new[] { 1, 32, 32, 64, }, new[] { 1, 1, 1, 64 }, new[] { 0, 3, 1, 2 }, false }, - { BinaryOp.Div, new[] { 64 }, new[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, - { BinaryOp.Div, Array.Empty(), new[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, - { BinaryOp.Sub, new[] { 32, 64 }, new[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, - { BinaryOp.Mul, new[] { 1, 1, 1, 64 }, new[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, + // BinaryOp binaryOp, long[] lShape, long[] rShape, int[] perm, bool leftConst + { BinaryOp.Add, new long[] { 1, 32, 32, 64, }, new long[] { 64 }, new[] { 0, 3, 1, 2 }, false }, + { BinaryOp.Add, new long[] { 1, 32, 32, 64, }, Array.Empty(), new[] { 0, 3, 1, 2 }, false }, + { BinaryOp.Sub, new long[] { 1, 32, 32, 64, }, new long[] { 32, 64 }, new[] { 0, 3, 1, 2 }, false }, + { BinaryOp.Mul, new long[] { 1, 32, 32, 64, }, new long[] { 1, 1, 1, 64 }, new[] { 0, 3, 1, 2 }, false }, + { BinaryOp.Div, new long[] { 64 }, new long[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, + { BinaryOp.Div, Array.Empty(), new long[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, + { BinaryOp.Sub, new long[] { 32, 64 }, new long[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, + { BinaryOp.Mul, new long[] { 1, 1, 1, 64 }, new long[] { 1, 32, 32, 64, }, new[] { 0, 3, 1, 2 }, true }, }; public static IEnumerable CombineBinaryTransposePositiveData => new[] { - new object[] { new[] { 5, 4 }, new[] { 5, 4 }, new[] { 1, 0 } }, - new object[] { new[] { 4, 4 }, new[] { 4, 4 }, new[] { 1, 0 } }, - new object[] { new[] { 4 }, new[] { 4 }, new[] { 0 } }, - new object[] { new[] { 1, 3, 4 }, new[] { 1, 3, 4 }, new[] { 0, 2, 1 } }, - new object[] { new[] { 1, 3, 2, 4 }, new[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 } }, + new object[] { new long[] { 5, 4 }, new long[] { 5, 4 }, new[] { 1, 0 } }, + new object[] { new long[] { 4, 4 }, new long[] { 4, 4 }, new[] { 1, 0 } }, + new object[] { new long[] { 4 }, new long[] { 4 }, new[] { 0 } }, + new object[] { new long[] { 1, 3, 4 }, new long[] { 1, 3, 4 }, new[] { 0, 2, 1 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 } }, }; public static IEnumerable CombineConstBinaryTransposeNotMatchData => new[] { - new object[] { new[] { 1, 3, 2, 4 }, new[] { 2, 3 }, new[] { 0, 3, 2, 1 } }, - new object[] { new[] { 1, 3, 2, 4 }, new[] { 2, 4, 3 }, new[] { 0, 2, 3, 1 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 2, 3 }, new[] { 0, 3, 2, 1 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 2, 4, 3 }, new[] { 0, 2, 3, 1 } }, }; public static IEnumerable CombineRConstBinaryTransposePositiveData => new[] { - new object[] { new[] { 1, 3, 2, 4 }, new[] { 3 }, new[] { 0, 3, 2, 1 } }, - new object[] { new[] { 1, 3, 2, 4 }, new[] { 3 }, new[] { 0, 2, 3, 1 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 3 }, new[] { 0, 3, 2, 1 } }, + new object[] { new long[] { 1, 3, 2, 4 }, new long[] { 3 }, new[] { 0, 2, 3, 1 } }, }; public static IEnumerable 
CombineLConstBinaryTransposePositiveData =>
         new[]
         {
-            new object[] { new[] { 3 }, new[] { 1, 3, 2, 4 }, new[] { 0, 3, 2, 1 } },
-            new object[] { new[] { 3 }, new[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 } },
+            new object[] { new long[] { 3 }, new long[] { 1, 3, 2, 4 }, new[] { 0, 3, 2, 1 } },
+            new object[] { new long[] { 3 }, new long[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 } },
         };

     public static IEnumerable TestCombineTransposeConcatPositiveData =>
         new[]
         {
-            new object[] { new[] { 4, 4 }, new[] { 1, 0 }, 1, 2 },
-            new object[] { new[] { 1, 3, 4 }, new[] { 0, 2, 1 }, 1, 6 },
-            new object[] { new[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 }, 2, 2 },
+            new object[] { new long[] { 4, 4 }, new[] { 1, 0 }, 1, 2 },
+            new object[] { new long[] { 1, 3, 4 }, new[] { 0, 2, 1 }, 1, 6 },
+            new object[] { new long[] { 1, 3, 2, 4 }, new[] { 0, 2, 3, 1 }, 2, 2 },
         };

     public static IEnumerable TestCombineTransposeConcatNegativeData =>
         new[]
         {
-            new object[] { new[] { 4, 4 }, new[] { new[] { 1, 0 }, new[] { 0, 1 } }, 1, 2, true },
-            new object[] { new[] { 1, 3, 2, 4 }, new[] { new[] { 0, 2, 3, 1 }, new[] { 0, 2, 3, 1 } }, 2, 2, false },
+            new object[] { new long[] { 4, 4 }, new[] { new[] { 1, 0 }, new[] { 0, 1 } }, 1, 2, true },
+            new object[] { new long[] { 1, 3, 2, 4 }, new[] { new[] { 0, 2, 3, 1 }, new[] { 0, 2, 3, 1 } }, 2, 2, false },
         };

     public static IEnumerable TestCombineTransposePadPositiveData =>
@@ -89,7 +89,7 @@ public class UnitTestCombineTranspose : TransformTestBase
         {
             new object[]
             {
-                new[] { 1, 3, 1, 2 }, new[] { 0, 3, 1, 2 },
+                new long[] { 1, 3, 1, 2 }, new long[] { 0, 3, 1, 2 },
                 new[,]
                 {
                     { 0, 0 },
@@ -100,7 +100,7 @@ public class UnitTestCombineTranspose : TransformTestBase
             },
             new object[]
             {
-                new[] { 1, 2, 3, 4 }, new[] { 0, 2, 3, 1 },
+                new long[] { 1, 2, 3, 4 }, new long[] { 0, 2, 3, 1 },
                 new[,]
                 {
                     { 4, 4 },
@@ -111,7 +111,7 @@ public class UnitTestCombineTranspose : TransformTestBase
             },
             new object[]
             {
-                new[] { 1, 2, 3, 4 }, new[] { 0, 3, 1, 2 },
+                new long[] { 1, 2, 3, 4 }, new long[] { 0, 3, 1, 2 },
                 new[,]
                 {
                     { 1, 1 },
@@ -122,7 +122,7 @@ public class UnitTestCombineTranspose : TransformTestBase
             },
             new object[]
             {
-                new[] { 5, 2, 3, 4 }, new[] { 3, 0, 1, 2 },
+                new long[] { 5, 2, 3, 4 }, new long[] { 3, 0, 1, 2 },
                 new[,]
                 {
                     { 2, 2 },
@@ -133,7 +133,7 @@ public class UnitTestCombineTranspose : TransformTestBase
             },
             new object[]
             {
-                new[] { 1, 2, 3, 4 }, new[] { 0, 3, 1, 2 },
+                new long[] { 1, 2, 3, 4 }, new long[] { 0, 3, 1, 2 },
                 new[,]
                 {
                     { 1, 1 },
@@ -147,17 +147,17 @@ public class UnitTestCombineTranspose : TransformTestBase
     public static IEnumerable TestCombineTransposeReducePositiveData =>
         new[]
         {
-            new object[] { new[] { 1, 3, 4 }, new[] { 0, 2, 1 }, 1, 0, false },
-            new object[] { new[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 }, 2, 1, true },
+            new object[] { new long[] { 1, 3, 4 }, new[] { 0, 2, 1 }, 1, 0, false },
+            new object[] { new long[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 }, 2, 1, true },
         };

     public static IEnumerable TestCombineTransposeUnaryPositiveData =>
         new[]
         {
-            new object[] { UnaryOp.Exp, new[] { 1, 3, 4 }, new[] { 0, 2, 1 } },
-            new object[] { UnaryOp.Sqrt, new[] { 1, 3, 4 }, new[] { 0, 2, 1 } },
-            new object[] { UnaryOp.Log, new[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 } },
-            new object[] { UnaryOp.Abs, new[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 } },
+            new object[] { UnaryOp.Exp, new long[] { 1, 3, 4 }, new[] { 0, 2, 1 } },
+            new object[] { UnaryOp.Sqrt, new long[] { 1, 3, 4 }, new[] { 0, 2, 1 } },
+            new object[] { UnaryOp.Log, new long[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 } },
+            new object[] { UnaryOp.Abs, new long[] { 1, 3, 4, 5 }, new[] { 0, 2, 3, 1 } },
         };

     [Theory]
@@ -215,7 +215,7 @@ public void TestCombineTransposeConcatNegative(int[] inShape, int[][] perm, int

     [Theory]
     [MemberData(nameof(CombineTransposeConstBinaryPositiveData))]
-    public void TestCombineTransposeConstBinaryPositive(BinaryOp binaryOp, int[] lShape, int[] rShape, int[] perm, bool leftConst)
+    public void TestCombineTransposeConstBinaryPositive(BinaryOp binaryOp, long[] lShape, long[] rShape, int[] perm, bool leftConst)
     {
         Expr lhs = leftConst ?
             lShape.Length == 0 ? 0.5f : Const.FromValue(Random.Normal(DataTypes.Float32, 0, 1, 3, lShape).Evaluate()) :
@@ -256,7 +256,7 @@ public void TestCombineBinaryTransposePositive(int[] lShape, int[] rShape, int[]

     [Theory]
     [MemberData(nameof(CombineConstBinaryTransposeNotMatchData))]
-    public void TestCombineConstTransposeNotMatch(int[] lShape, int[] rShape, int[] perm)
+    public void TestCombineConstTransposeNotMatch(long[] lShape, long[] rShape, int[] perm)
     {
         var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape);
         var b = Tensor.From(Random.Normal(DataTypes.Float32, 0, 1, 0, rShape).Evaluate().AsTensor().ToArray(), rShape);
@@ -268,7 +268,7 @@ public void TestCombineConstTransposeNotMatch(int[] lShape, int[] rShape, int[]

     [Theory]
     [MemberData(nameof(CombineRConstBinaryTransposePositiveData))]
-    public void TestCombineTransposeRConstBinaryPositive(int[] lShape, int[] rShape, int[] perm)
+    public void TestCombineTransposeRConstBinaryPositive(long[] lShape, long[] rShape, int[] perm)
     {
         var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape);
         var b = Tensor.From(Random.Normal(DataTypes.Float32, 0, 1, 0, rShape).Evaluate().AsTensor().ToArray(), rShape);
@@ -280,7 +280,7 @@ public void TestCombineTransposeRConstBinaryPositive(int[] lShape, int[] rShape,

     [Theory]
     [MemberData(nameof(CombineLConstBinaryTransposePositiveData))]
-    public void TestCombineLConstBinaryTransposePositive(int[] lShape, int[] rShape, int[] perm)
+    public void TestCombineLConstBinaryTransposePositive(long[] lShape, long[] rShape, int[] perm)
     {
         var a = Tensor.From(Random.Normal(DataTypes.Float32, 0, 1, 0, lShape).Evaluate().AsTensor().ToArray(), lShape);
         var b = Random.Normal(DataTypes.Float32, 0, 1, 0, rShape);
@@ -292,7 +292,7 @@ public void TestCombineLConstBinaryTransposePositive(int[] lShape, int[] rShape,

     [Theory]
     [MemberData(nameof(CombineLConstBinaryTransposePositiveData))]
-    public void TestCombineLConstBinaryTransposeNotFloat(int[] lShape, int[] rShape, int[] perm)
+    public void TestCombineLConstBinaryTransposeNotFloat(long[] lShape, long[] rShape, int[] perm)
     {
         var a = Random.Normal(DataTypes.Int64, 0, 1, 0, lShape).Evaluate().AsTensor();
         var b = Random.Normal(DataTypes.Int64, 0, 1, 0, rShape);
@@ -304,7 +304,7 @@ public void TestCombineLConstBinaryTransposeNotFloat(int[] lShape, int[] rShape,

     [Theory]
     [MemberData(nameof(TestCombineTransposePadPositiveData))]
-    public void TestCombineTransposePadPositive(int[] inShape, int[] perm, int[,] paddings, PadMode padM, float padValue)
+    public void TestCombineTransposePadPositive(long[] inShape, int[] perm, int[,] paddings, PadMode padM, float padValue)
     {
         var a = new Var("input", new TensorType(DataTypes.Float32, inShape));
         var normal = new Dictionary();
@@ -322,7 +322,7 @@ public void TestCombineTransposePadPositive(int[] inShape, int[] perm, int[,] pa

     [Theory]
     [MemberData(nameof(TestCombineTransposePadPositiveData))]
-    public void TestCombinePadTransposePositive(int[] inShape, int[] perm, int[,] paddings, PadMode padM, float padValue)
+    public void TestCombinePadTransposePositive(long[] inShape, int[] perm, int[,] paddings, PadMode padM, float padValue)
     {
         var a = new Var("input", new TensorType(DataTypes.Float32, inShape));
         var normal = new Dictionary();
@@ -340,7 +340,7 @@ public void TestCombinePadTransposePositive(int[] inShape, int[] perm, int[,] pa

     [Theory]
     [MemberData(nameof(TestCombineTransposeReducePositiveData))]
-    public void TestCombineTransposeReducePositive(int[] inShape, int[] perm, int axis, int initValue, bool keepDims)
+    public void TestCombineTransposeReducePositive(long[] inShape, int[] perm, int axis, int initValue, bool keepDims)
     {
         var a = new Var();
         var normal = new Dictionary();
@@ -351,7 +351,7 @@ public void TestCombineTransposeReducePositive(int[] inShape, int[] perm, int ax

     [Theory]
     [MemberData(nameof(TestCombineTransposeUnaryPositiveData))]
-    public void TestCombineTransposeUnaryPositive(UnaryOp opType, int[] inShape, int[] perm)
+    public void TestCombineTransposeUnaryPositive(UnaryOp opType, long[] inShape, int[] perm)
     {
         var a = new Var();
         var normal = new Dictionary();
@@ -362,7 +362,7 @@ public void TestCombineTransposeUnaryPositive(UnaryOp opType, int[] inShape, int

     [Theory]
     [ClassData(typeof(CombineTransposeReshapePostiveData))]
-    public void TestCombineTransposeReshapePostive(int[] inShape, int[] newShape, int[] perm)
+    public void TestCombineTransposeReshapePostive(long[] inShape, long[] newShape, int[] perm)
     {
         var a = new Var(new TensorType(DataTypes.Float32, inShape));
         var feed_dict = new Dictionary
@@ -373,7 +373,7 @@ public void TestCombineTransposeReshapePostive(int[] inShape, int[] newShape, in
         TestMatched(rootPre, feed_dict);
     }

-    private sealed class CombineTransposeReshapePostiveData : TheoryData
+    private sealed class CombineTransposeReshapePostiveData : TheoryData
     {
         public CombineTransposeReshapePostiveData()
         {
@@ -390,7 +390,7 @@ public CombineTransposeReshapePostiveData()

             foreach (var (a, b, c) in new[] { inshapes, newShapes, perms }.CartesianProduct().Select(i => i.ToArray()).Select(i => (i[0], i[1], i[2])))
             {
-                Add(a, b, c);
+                Add(a.ToLongs(), b.ToLongs(), c);
             }
         }
     }
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs
index e0a88f473..e2c17ae39 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs
@@ -37,7 +37,7 @@ public class UnitTestFlattenToReshape : TransformTestBase
     public static IEnumerable TestFlattenToReshapeNegativeData =>
         new[]
         {
-            new object[] { new[] { 2, 4, IR.Dimension.Unknown }, 1 },
+            new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, 1 },
         };

     [Theory]
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldLayerNorm.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldLayerNorm.cs
index a91ca0c1a..2c03937a1 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldLayerNorm.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldLayerNorm.cs
@@ -107,7 +107,7 @@ public void TestFoldLayerNormPositive2(int[] shape)
         var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v0, axes, initValue, keepDims);
         var v3 = IR.F.Math.Binary(BinaryOp.Sub, v0, v2);
         var v4 = IR.F.Math.Binary(BinaryOp.Pow, v3, 2f);
-        var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+        var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, [1]), initValue, keepDims);
         var v6 = IR.F.Math.Binary(BinaryOp.Add, v5, 1e-05f);
         var v7 = IR.F.Math.Unary(UnaryOp.Sqrt, v6);
         var v8 = IR.F.Math.Binary(BinaryOp.Div, v3, v7);
@@ -134,7 +134,7 @@ public void TestFoldLayerNormNegative2(int[] shape)
         var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v0, axes, initValue, keepDims);
         var v3 = IR.F.Math.Binary(BinaryOp.Sub, v0, v2);
         var v4 = IR.F.Math.Binary(BinaryOp.Pow, v3, 2f);
-        var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+        var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, [1]), initValue, keepDims);
         var v6 = IR.F.Math.Binary(BinaryOp.Add, v5, 1e-05f);
         var v7 = IR.F.Math.Unary(UnaryOp.Sqrt, v6);
         var v8 = IR.F.Math.Binary(BinaryOp.Div, v4, v7);
@@ -167,7 +167,7 @@ public void TestFoldLayerNormPositive3(int[] shape)
         var v3 = IR.F.Tensors.Reduce(ReduceOp.Mean, v0, axes, initValue, keepDims);
         var v4 = IR.F.Math.Binary(BinaryOp.Sub, v0, v3);
         var v5 = IR.F.Math.Unary(UnaryOp.Square, v4);
-        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, [1]), initValue, keepDims);
         var v7 = IR.F.Math.Binary(BinaryOp.Add, v6, 1e-05f);
         var v8 = IR.F.Math.Unary(UnaryOp.Rsqrt, v7);
         var v9 = IR.F.Math.Binary(BinaryOp.Mul, v8, 0.5f);
@@ -196,7 +196,7 @@ public void TestFoldLayerNormNegative3(int[] shape)
         var v3 = IR.F.Tensors.Reduce(ReduceOp.Mean, v0, axes, initValue, keepDims);
         var v4 = IR.F.Math.Binary(BinaryOp.Sub, v0, v3);
         var v5 = IR.F.Math.Unary(UnaryOp.Sqrt, v4);
-        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, [1]), initValue, keepDims);
         var v7 = IR.F.Math.Binary(BinaryOp.Add, v6, 1e-05f);
         var v8 = IR.F.Math.Unary(UnaryOp.Rsqrt, v7);
         var v9 = IR.F.Math.Binary(BinaryOp.Mul, v8, 0.5f);
@@ -228,10 +228,10 @@ public void TestFoldLayerNormPositive4(int[] shape)
         Expr rootPre;
         {
             var v1 = input;
-            var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v1, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+            var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v1, Tensor.From(axes, [1]), initValue, keepDims);
             var v3 = IR.F.Math.Binary(BinaryOp.Sub, v1, v2);
             var v4 = IR.F.Math.Binary(BinaryOp.Mul, v3, v3);
-            var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+            var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, [1]), initValue, keepDims);
             var v6 = IR.F.Math.Binary(BinaryOp.Add, v5, 1e-05f);
             var v7 = IR.F.Math.Unary(UnaryOp.Rsqrt, v6);
             var v8 = IR.F.Math.Binary(BinaryOp.Mul, v7, 0.05f);
@@ -249,10 +249,10 @@ public void TestFoldLayerNormPositive4(int[] shape)
         var beta = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 4, shape[^1]);
         {
             var v1 = input1;
-            var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v1, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+            var v2 = IR.F.Tensors.Reduce(ReduceOp.Mean, v1, Tensor.From(axes, [1]), initValue, keepDims);
             var v3 = IR.F.Math.Binary(BinaryOp.Sub, v1, v2);
             var v4 = IR.F.Math.Binary(BinaryOp.Mul, v3, v3);
-            var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+            var v5 = IR.F.Tensors.Reduce(ReduceOp.Mean, v4, Tensor.From(axes, [1]), initValue, keepDims);
             var v6 = IR.F.Math.Binary(BinaryOp.Add, v5, 1e-05f);
             var v7 = IR.F.Math.Unary(UnaryOp.Rsqrt, v6);
             var v8 = IR.F.Math.Binary(BinaryOp.Mul, v7, gamma.Evaluate().AsTensor());
@@ -281,7 +281,7 @@ public void TestFoldLayerNormNegative4(int[] shape)
         var v3 = IR.F.Tensors.Reduce(ReduceOp.Mean, v0, axes, initValue, keepDims);
         var v4 = IR.F.Math.Binary(BinaryOp.Sub, v0, v3);
         var v5 = IR.F.Math.Unary(UnaryOp.Sqrt, v4);
-        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, new[] { 1 }), initValue, keepDims);
+        var v6 = IR.F.Tensors.Reduce(ReduceOp.Mean, v5, Tensor.From(axes, [1]), initValue, keepDims);
         var v7 = IR.F.Math.Binary(BinaryOp.Add, v6, 1e-05f);
         var v8 = IR.F.Math.Unary(UnaryOp.Rsqrt, v7);
         var v9 = IR.F.Math.Binary(BinaryOp.Mul, v8, 0.5f);
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs
index d2fdea0d3..89ead9bf6 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs
@@ -34,7 +34,7 @@ public class UnitTestReshapeBatchMatmul : TransformTestBase
     public static IEnumerable TestReshapeBatchMatmulNegativeData =>
         new[]
         {
-            new object[] { new[] { 2, 1, 4 }, new[] { 4, Dimension.Unknown } },
+            new object[] { new[] { 2, 1, 4 }, new[] { 4, Dimension.Unknown() } },
         };

     [Theory]
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
index 563693c66..d02aaabba 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
@@ -34,7 +34,7 @@ public class UnitTestSpaceToBatchToPad : TransformTestBase
         new[]
         {
             new object[] { new[] { 1, 128, 128, new IR.Dimension(1) }, new[] { 2, 2 }, new[,] { { 0, 0 }, { 0, 0 } } },
-            new object[] { new[] { 1, 128, 128, IR.Dimension.Unknown }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } },
+            new object[] { new[] { 1, 128, 128, IR.Dimension.Unknown() }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } },
         };

     [Theory]
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs
index ece1de55f..f56c2a4d3 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs
@@ -33,7 +33,7 @@ public class UnitTestSqueezeToReshape : TransformTestBase
     public static IEnumerable TestSqueezeToReshapeNegativeData =>
         new[]
         {
-            new object[] { new[] { 2, 4, IR.Dimension.Unknown }, Array.Empty() },
+            new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, Array.Empty() },
         };

     [Theory]
diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs
index f6b4f9c9b..0f6b9dd02 100644
--- a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs
+++ b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs
@@ -30,7 +30,7 @@ public class UnitTestUnSqueezeToReshape : TransformTestBase
         };

     public static IEnumerable TestUnSqueezeToReshapeNegativeData =>
-        new[] { new object[] { new[] { 2, 4, IR.Dimension.Unknown }, new[] { -1 } }, };
+        new[] { new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, new[] { -1 } }, };

     [Theory]
     [MemberData(nameof(TestUnSqueezeToReshapePositiveData))]
diff --git a/src/Nncase.Tests/Rules/Packing/PackUtilityTest.cs b/src/Nncase.Tests/Rules/Packing/PackUtilityTest.cs
index 2552bb1a2..a005a846e 100644
--- a/src/Nncase.Tests/Rules/Packing/PackUtilityTest.cs
+++ b/src/Nncase.Tests/Rules/Packing/PackUtilityTest.cs
@@ -13,10 +13,10 @@ namespace Nncase.Tests.Rules.Packing;
 public sealed class PackUtilityTest
 {
     [Theory]
-    [InlineData(new object[] { new int[] { 1, 3, 2, 3, 1, 1, 7 }, new int[] { 1, 1, 3, 6, 1, 7 }, true })]
-    [InlineData(new object[] { new int[] { 2, 3, 4 }, new int[] { 4, 3, 2 }, false })]
-    [InlineData(new object[] { new int[] { 4, 4096 }, new int[] { 1, 4, 64, 64 }, true })]
-    public void TestComputeReshapeMapping(int[] inShape, int[] newShape, bool valid)
+    [InlineData(new object[] { new long[] { 1, 3, 2, 3, 1, 1, 7 }, new long[] { 1, 1, 3, 6, 1, 7 }, true })]
+    [InlineData(new object[] { new long[] { 2, 3, 4 }, new long[] { 4, 3, 2 }, false })]
+    [InlineData(new object[] { new long[] { 4, 4096 }, new long[] { 1, 4, 64, 64 }, true })]
+    public void TestComputeReshapeMapping(long[] inShape, long[] newShape, bool valid)
     {
         Assert.Equal(valid, PackUtility.TryGetShapeMapMatrix(inShape, newShape, out var mat));
         if (valid)
diff --git a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
index cec7da22a..19f7fccef 100644
--- a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
+++ b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs
@@ -64,7 +64,7 @@ public void TestBucketPad()
     [Fact]
     public async Task TestSingleVarFusionBucket()
     {
-        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var dimVar = Scalar("dimVar");
         CompileOptions.ShapeBucketOptions.Enable = true;
         CompileOptions.ShapeBucketOptions.SegmentsCount = 2;
@@ -73,7 +73,7 @@ public async Task TestSingleVarFusionBucket()
         CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } };
         var input = Testing.Rand(1, 3, 24, 24);

-        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar });
         var main = new Function("main", new Call(f, mainVar), mainVar);
         var shape = new Dictionary();
@@ -87,7 +87,7 @@ public async Task TestSingleVarFusionBucket()
     [Fact]
     public async Task TestRebuild()
     {
-        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var dimVar = Scalar("dimVar");
         CompileOptions.ShapeBucketOptions.Enable = true;
         CompileOptions.ShapeBucketOptions.SegmentsCount = 2;
@@ -96,7 +96,7 @@ public async Task TestRebuild()
         CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } };
         var input = Testing.Rand(1, 3, 24, 24);

-        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var shapeVar = new Var(new TensorType(DataTypes.Int64, new[] { 4 }));
         var body = IR.F.Math.MatMul(Reshape(fusionVar, shapeVar), fusionVar);
         var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar, shapeVar }, new[] { dimVar });
@@ -113,7 +113,7 @@ public async Task TestRebuild()
     [Fact]
     public async Task TestTupleOutput()
     {
-        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var dimVar = Scalar("dimVar");
         CompileOptions.ShapeBucketOptions.Enable = true;
         CompileOptions.ShapeBucketOptions.SegmentsCount = 2;
@@ -122,7 +122,7 @@ public async Task TestTupleOutput()
         CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } };
         var input = Testing.Rand(1, 3, 24, 24);

-        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var mm = IR.F.Math.MatMul(fusionVar, fusionVar);
         var body = new IR.Tuple(mm, mm);
         var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar }, new[] { dimVar });
@@ -138,8 +138,8 @@ public async Task TestTupleOutput()
     [Fact]
     public async Task TestDoubleVarFusionBucket()
     {
-        var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
-        var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
+        var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var dimVar1 = Scalar("dimVar1");
         var dimVar2 = Scalar("dimVar2");
         CompileOptions.ShapeBucketOptions.Enable = true;
@@ -158,7 +158,7 @@ public async Task TestDoubleVarFusionBucket()
         var inputLhs = Testing.Rand(1, 3, 24, 24);
         var inputRhs = Testing.Rand(1, 3, 24, 24);

-        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 });
         var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs);
         var shape = new Dictionary();
@@ -176,8 +176,8 @@ public async Task TestDoubleVarFusionBucket()
     [Fact]
     public async Task TestDoubleVarWithMultiDimEffect()
     {
-        var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
-        var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, 1, 24, 24 }));
+        var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
+        var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown(), 1, 24, 24 }));
         var dimVar1 = Scalar("dimVar1");
         var dimVar2 = Scalar("dimVar2");
         CompileOptions.ShapeBucketOptions.Enable = true;
@@ -196,7 +196,7 @@ public async Task TestDoubleVarWithMultiDimEffect()
         var inputLhs = Testing.Rand(1, 3, 24, 24);
         var inputRhs = Testing.Rand(3, 1, 24, 24);

-        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 }));
+        var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 }));
         var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 });
         var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs);
         var shape = new Dictionary();
diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs
index bab7175f5..a043596f1 100644
--- a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs
+++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs
@@ -21,7 +21,7 @@ public class UnitTestFoldGetItemShapeOf : TransformTestBase
     [Fact]
     public void TestFoldGetItemShapeOf()
     {
-        var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 }));
+        var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 24 }));
         var data = Testing.Rand(1, 3, 24, 24);
         var dict = new Dictionary { { input, Value.FromTensor(data) } };
         TestMatched(ShapeOf(input)[1], dict);
@@ -30,7 +30,7 @@ public void TestFoldGetItemShapeOf()
     [Fact]
     public void TestFoldGetItemShapeOfWithCast()
     {
-        var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 }));
+        var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 24 }));
         var data = Testing.Rand(1, 3, 24, 24);
         var dict = new Dictionary { { input, Value.FromTensor(data) } };
         TestMatched(Cast(ShapeOf(input), DataTypes.Int32)[1], dict);
@@ -39,7 +39,7 @@ public void TestFoldGetItemShapeOfWithCast()
     [Fact]
     public void TestFoldGetItemShapeOfWithDynamic()
     {
-        var input = new Var(new TensorType(DataTypes.Int32, new[] { 1, 3, Dimension.Unknown, 24 }));
+        var input = new Var(new TensorType(DataTypes.Int32, new[] { 1, 3, Dimension.Unknown(), 24 }));
         TestNotMatch(ShapeOf(input)[2]);
     }
 }
diff --git a/src/Nncase.Tests/Simulator/UnitTestInterop.cs b/src/Nncase.Tests/Simulator/UnitTestInterop.cs
index 603709129..dcf3b1229 100644
--- a/src/Nncase.Tests/Simulator/UnitTestInterop.cs
+++ b/src/Nncase.Tests/Simulator/UnitTestInterop.cs
@@ -104,8 +104,8 @@ public void TestCreateTensorFromTensor()
         var dtype = RTDataType.FromTypeCode(Runtime.TypeCode.Float32);
         Assert.NotNull(rtTensor);
         Assert.Equal(dtype, rtTensor.ElementType);
-        Assert.Equal(MemoryMarshal.Cast(tensor.Dimensions).ToArray(), rtTensor.Dimensions.ToArray());
-        Assert.Equal(MemoryMarshal.Cast(tensor.Strides).ToArray(), rtTensor.Strides.ToArray());
+        Assert.Equal(MemoryMarshal.Cast(tensor.Dimensions.ToInts()).ToArray(), rtTensor.Dimensions.ToArray());
+        Assert.Equal(MemoryMarshal.Cast(tensor.Strides.ToInts()).ToArray(), rtTensor.Strides.ToArray());

         var buffer = rtTensor.Buffer.Buffer.AsHost()!;
         using (var mmOwner = buffer.Map(RTMapAccess.Read))
diff --git a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs
index 8e7196ba4..5b336137b 100644
--- a/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs
+++ b/src/Nncase.Tests/TIR/PrimFunc/IDataFlowPrimFuncCase.cs
@@ -26,7 +26,7 @@ public interface IDataFlowPrimFuncCase

 internal static class PrimFuncBuilder
 {
-    public static readonly int[] Dimensions = new[] { 1, 4, 8, 9 };
+    public static readonly long[] Dimensions = [1, 4, 8, 9];

     private static int _count;
diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
index c8fd4859d..4a3387603 100644
--- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
+++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
@@ -422,8 +422,8 @@ public async Task TestPackReshape(int[] inshape, int[] outshape, int packRank, i
     }

     [Theory]
-    [InlineData([new int[] { 2, 8, 16, 2 }, new int[] { 0, 2, 1, 3 }, 2, 0])]
-    public async Task TestTranspose(int[] shape, int[] perm, int rank, int number)
+    [InlineData([new long[] { 2, 8, 16, 2 }, new int[] { 0, 2, 1, 3 }, 2, 0])]
+    public async Task TestTranspose(long[] shape, int[] perm, int rank, int number)
     {
         var input = new Var("input", new TensorType(DataTypes.Float32, shape));
         Expr pre; // f32[1,3,28,28]
diff --git a/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs b/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs
index 6c92c0e22..aeab930d1 100644
--- a/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs
+++ b/src/Nncase.Tests/Targets/UnitTestCPUTarget.cs
@@ -162,7 +162,7 @@ public void TestTupleOrder()
     [MemberData(nameof(TestGetItemData))]
     public void TestGetItem(int[] index)
     {
-        var input = Tensor.From(new[] { 1, 2, 3, 4, 5, 6 }, new[] { 1, 2, 3 });
+        var input = Tensor.From(new[] { 1, 2, 3, 4, 5, 6 }, [1, 2, 3]);
         var x = new Var("x", new TensorType(DataTypes.Int32, new[] { 1, 2, 3 }));
         var second = GetItem(x, index);
         var main = new Function("main", second, new[] { x });