diff --git a/CMakeLists.txt b/CMakeLists.txt
index e8ec85d868..cbeac44cfb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -8,7 +8,6 @@ project(onnx-mlir)
 option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
 option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
 option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
-option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON)
 option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
 option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
 option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
@@ -223,12 +222,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
   add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
 endif()
 
-if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
-  add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
-  set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 1)
-else()
-  set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 0)
-endif()
 
 add_subdirectory(utils)
 add_subdirectory(include)
diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp
index 79faccb733..de3c5f1fe8 100644
--- a/src/Compiler/CompilerOptions.cpp
+++ b/src/Compiler/CompilerOptions.cpp
@@ -78,6 +78,7 @@ bool enableParallel; // onnx-mlir only
 bool disableSimdOption;      // onnx-mlir only
 bool enableFastMathOption;   // onnx-mlir only
 bool disableRecomposeOption; // onnx-mlir only
+bool disableConvTransposeDecomposeOption; // onnx-mlir only
 bool enableSimdDataLayout;   // onnx-mlir only
 bool verifyInputTensors;     // onnx-mlir only
 bool allowSorting;           // onnx-mlir only
@@ -257,6 +258,12 @@ static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
     llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
     llvm::cl::cat(OnnxMlirOptions));
 
+static llvm::cl::opt<bool, true> disableConvTransposeDecomposeOptionOpt(
+    "disable-convtranspose-decompose",
+    llvm::cl::desc("Disable decomposition of ONNX ConvTranspose operator."),
+    llvm::cl::location(disableConvTransposeDecomposeOption),
+    llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
+
 // Options for onnx-mlir only
 static llvm::cl::opt<EmissionTargetType> emissionTargetOpt(
     llvm::cl::desc("Choose target to emit:"),
diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp
index 0d45dcfb1e..69f8d0246f 100644
--- a/src/Compiler/CompilerOptions.hpp
+++ b/src/Compiler/CompilerOptions.hpp
@@ -123,6 +123,7 @@ extern bool enableParallel; // onnx-mlir only
 extern bool disableSimdOption;      // onnx-mlir only
 extern bool enableFastMathOption;   // onnx-mlir only
 extern bool disableRecomposeOption; // onnx-mlir only
+extern bool disableConvTransposeDecomposeOption; // onnx-mlir only
 extern bool enableSimdDataLayout;   // onnx-mlir only
 extern bool verifyInputTensors;     // onnx-mlir only
 extern bool allowSorting;           // onnx-mlir only
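Note on the mechanism: the build-time ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE switch is replaced by a runtime flag backed by the disableConvTransposeDecomposeOption global. The patch relies on LLVM's external-storage cl::opt idiom; a minimal self-contained sketch of that idiom follows (option and variable names here are illustrative, not part of the patch):

    #include "llvm/Support/CommandLine.h"

    bool myFeatureDisabled; // global that the rest of the compiler reads

    // The second template argument (ExternalStorage = true) together with
    // cl::location makes option parsing write directly into the global.
    static llvm::cl::opt<bool, true> myFeatureDisabledOpt("disable-my-feature",
        llvm::cl::desc("Disable the feature."),
        llvm::cl::location(myFeatureDisabled), llvm::cl::init(false));

    int main(int argc, char **argv) {
      llvm::cl::ParseCommandLineOptions(argc, argv);
      return myFeatureDisabled ? 1 : 0; // reflects --disable-my-feature
    }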
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
index d0caaf4161..a52c54053c 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -57,6 +57,12 @@ Value TosaBuilder::createConst(
 }
 
 bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) {
+  if (llvm::any_of(valueRange, [](const auto value) {
+        return !mlir::cast<ShapedType>(value.getType()).hasRank();
+      })) {
+    return false; // we have no way to determine the broadcast, so do not
+                  // attempt it
+  }
   int64_t firstRank = mlir::cast<ShapedType>(valueRange[0].getType()).getRank();
   for (Value operand : valueRange) {
     auto operandType = mlir::cast<ShapedType>(operand.getType());
@@ -129,9 +135,8 @@ Value TosaBuilder::getConst(ArrayRef vec, ArrayRef shape) {
   return constOp;
 }
 
-Value TosaBuilder::getSplattedConst(
-    float val, Type dtype, llvm::ArrayRef<int64_t> shape) {
-  auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
+Value TosaBuilder::getSplattedConst(float val, Type dtype, int64_t rank) {
+  auto constType = tosa::reduceAxisToOne(rank, rewriter().getF32Type());
   auto constAttr = DenseElementsAttr::get(constType, val);
 
   auto constOp =
@@ -150,8 +155,7 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {
   auto valueType = mlir::cast<ShapedType>(value.getType());
   // get new value type
   Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(
-          valueType.getShape().size(), ShapedType::kDynamic),
+      llvm::SmallVector<int64_t>(perm.size(), ShapedType::kDynamic),
       valueType.getElementType());
   // create transpose for value
   Value newValue = tosa::CreateOpAndInfer<mlir::tosa::TransposeOp>(
@@ -195,9 +199,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
       rewriter(), loc(), newValueType, lhs, rhs, shift);
 }
@@ -215,9 +222,12 @@ Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
   }
 
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsElementType);
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsElementType);
   return tosa::CreateOpAndInfer<mlir::tosa::IntDivOp>(
       rewriter(), loc(), newValueType, lhs, rhs);
 }
@@ -230,9 +240,12 @@ Value TosaBuilder::binaryOp(Value &lhs, Value &rhs) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, lhs, rhs);
 }
 
@@ -246,11 +259,7 @@ template Value TosaBuilder::binaryOp(
 template <typename T>
 Value TosaBuilder::unaryOp(mlir::Value &input) {
-  auto inputType = cast<TensorType>(input.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(inputType.getRank(), ShapedType::kDynamic),
-      inputType.getElementType());
-  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, input);
+  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), input.getType(), input);
 }
 
 template Value TosaBuilder::unaryOp(mlir::Value &input);
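Note: getSplattedConst now takes the result rank instead of a shape. Every constant it builds is a <1x1x...x1> splat, so only the rank of the old shape argument was ever consulted, and passing the rank also works when individual dimensions are dynamic. Call sites migrate mechanically, e.g. (illustrative snippet, not taken from the patch):

    // Before: the shape argument was only used for its size.
    // Value half = tosaBuilder.getSplattedConst(0.5f, elemTy, ty.getShape());
    // After: pass the rank; valid for static and dynamic ranked types alike.
    Value half = tosaBuilder.getSplattedConst(0.5f, elemTy, ty.getRank());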
@@ -305,9 +314,12 @@ Value TosaBuilder::select(
     rhs = valueVec[2];
   }
   auto lhsType = cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::SelectOp>(
       rewriter(), loc(), newValueType, cond, lhs, rhs);
 }
@@ -328,7 +340,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType(
 Value TosaBuilder::sqrt(mlir::Value &input) {
   auto inputType = cast<ShapedType>(input.getType());
   auto oneHalf = this->getSplattedConst(
-      0.5, inputType.getElementType(), inputType.getShape());
+      0.5, inputType.getElementType(), inputType.getRank());
   return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
 }
 
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
index b97c912692..333f90c456 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -91,8 +91,7 @@ struct TosaBuilder : DialectBuilder {
   // The tensor will have the same rank as shape but all dimensions will
   // have size 1 (differs from tensorflow impl.)
   // If dtype is provided, it also cast the value to the appropriate dtype.
-  mlir::Value getSplattedConst(
-      float val, mlir::Type dtype, llvm::ArrayRef<int64_t> shape = {});
+  mlir::Value getSplattedConst(float val, mlir::Type dtype, int64_t rank);
 
   // Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
   // to `value`.
diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
index eeee6ab7f0..9e259174ca 100644
--- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -108,6 +108,14 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(
   return success();
 }
 
+namespace {
+template <typename OnnxOp>
+void copySingleResultType(OnnxOp opToCopyFrom, Value &valueToCopyTo) {
+  assert(opToCopyFrom->getNumResults() == 1);
+  valueToCopyTo.setType(opToCopyFrom->getResult(0).getType());
+}
+} // namespace
+
 // Element-wise unary ops lowering to TOSA dialect.
 //===----------------------------------------------------------------------===//
 template {
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     Value mulOp = tosaBuilder.mul(lhs, rhs);
+    copySingleResultType(op, mulOp);
 
     rewriter.replaceOp(op, {mulOp});
     return success();
@@ -230,12 +239,12 @@ static LogicalResult legalizeFloatingPointPrelu(Operation *op,
   auto loc = op->getLoc();
   TosaBuilder tosaBuilder(rewriter, loc);
   Value constZero = tosaBuilder.getSplattedConst(
-      0.0, outputType.getElementType(), outputType.getShape());
+      0.0, outputType.getElementType(), outputType.getRank());
 
   auto mul = tosaBuilder.mul(input, alphaOrSlope);
   auto greaterEqual = tosaBuilder.greaterEqual(input, constZero);
   auto select = tosaBuilder.select(greaterEqual, input, mul);
-
+  copySingleResultType(op, select);
   rewriter.replaceOp(op, {select});
   return success();
 }
@@ -265,7 +274,7 @@ class ONNXLeakyReluOpLoweringToTOSA
     TosaBuilder tosaBuilder(rewriter, loc);
     return legalizeFloatingPointPrelu(op, rewriter, adaptor.getX(),
         tosaBuilder.getSplattedConst(
-            alpha, outputType.getElementType(), outputType.getShape()),
+            alpha, outputType.getElementType(), outputType.getRank()),
         outputType);
   }
 };
@@ -303,6 +312,7 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern<ONNXOp> {
     } else if constexpr (std::is_same_v<ONNXOp, ONNXLessOp>) {
       res = tosaBuilder.less(input1, input2);
     }
+    copySingleResultType(op, res);
     rewriter.replaceOp(op, {res});
     return success();
   }
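Note: copySingleResultType stamps the ONNX op's inferred result type back onto the freshly built TOSA value. CreateOpAndInfer runs TOSA's own shape inference, which can produce a type less precise than what ONNX shape inference already established (e.g. fully dynamic dimensions); re-asserting the original type keeps the replacement value consistent with the op's uses. The usage pattern, condensed from the hunks above:

    Value res = tosaBuilder.mul(lhs, rhs); // may come back with a vaguer type
    copySingleResultType(op, res);         // re-assert the ONNX result type
    rewriter.replaceOp(op, {res});         // uses now agree on the type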
@@ -384,7 +394,7 @@ class ONNXCastOpLoweringToTOSA : public OpConversionPattern<ONNXCastOp> {
     // onnx.Cast and tosa.cast.
     if (resultTy.getElementType().getIntOrFloatBitWidth() != 1) {
       auto zero = tosaBuilder.getSplattedConst(
-          0.0f, inputTy.getElementType(), resultTy.getShape());
+          0.0f, inputTy.getElementType(), resultTy.getRank());
       auto positive = tosaBuilder.greaterEqual(input, zero);
 
       auto floor = tosaBuilder.unaryOp<mlir::tosa::FloorOp>(input);
@@ -412,6 +422,7 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
 
     if (isa<IntegerType>(resultElementType)) {
       Value divOp = tosaBuilder.intdiv(lhs, rhs);
+      copySingleResultType(op, divOp);
       rewriter.replaceOp(op, {divOp});
       return success();
     }
@@ -419,6 +430,7 @@
     // tosa::ReciprocalOp and tosa::MulOp.
     Value reciprocalOp = tosaBuilder.unaryOp<mlir::tosa::ReciprocalOp>(rhs);
     Value mulOp = tosaBuilder.mul(lhs, reciprocalOp);
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
@@ -463,20 +475,21 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     Value one = tosaBuilder.getSplattedConst(
-        1.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        1.0, resultTensorType.getElementType(), resultTensorType.getRank());
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            resultTensorType.getElementType(), resultTensorType.getShape());
+            resultTensorType.getElementType(), resultTensorType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        0.0, resultTensorType.getElementType(), resultTensorType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, exp);
     Value expMinusOne = tosaBuilder.binaryOp<mlir::tosa::SubOp>(exp, one);
     Value alphaTimesExpMinusOne = tosaBuilder.mul(expMinusOne, alpha);
     Value greaterEqual = tosaBuilder.greaterEqual(input, constZero);
     auto select =
         tosaBuilder.select(greaterEqual, input, alphaTimesExpMinusOne);
-
+    copySingleResultType(op, select);
     rewriter.replaceOp(op, {select});
     return success();
   }
@@ -507,11 +520,16 @@ class ONNXHardSigmoidOpLoweringToTOSA
     APFloat oneOverAlpha(alpha.getSemantics(), 1);
     oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);
 
+    if (!resultType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "HardSigmoid: Static shape required to create splatted const");
+    }
+
     Value constBetaOverAlpha =
         tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
-            resultElementType, resultType.getShape());
+            resultElementType, resultType.getRank());
     Value constAlpha = tosaBuilder.getSplattedConst(
-        alpha.convertToDouble(), resultElementType, resultType.getShape());
+        alpha.convertToDouble(), resultElementType, resultType.getRank());
 
     auto addOp =
         tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
@@ -521,7 +539,7 @@
         rewriter.getF32FloatAttr(0),
         rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
     auto mulOp = tosaBuilder.mul(clampOp, constAlpha);
-
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
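Note: onnx.Div lowers along two paths because TOSA has no floating-point divide: integer element types map to tosa.int_div directly, while float division is rewritten as lhs * reciprocal(rhs). A condensed sketch of the dispatch implemented above (error handling omitted):

    if (isa<IntegerType>(resultElementType)) {
      Value divOp = tosaBuilder.intdiv(lhs, rhs); // -> tosa.int_div
      copySingleResultType(op, divOp);
      rewriter.replaceOp(op, {divOp});
    } else {
      Value rec = tosaBuilder.unaryOp<mlir::tosa::ReciprocalOp>(rhs);
      Value mulOp = tosaBuilder.mul(lhs, rec);    // lhs * (1 / rhs)
      copySingleResultType(op, mulOp);
      rewriter.replaceOp(op, {mulOp});
    }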
@@ -556,14 +574,19 @@ class ONNXSoftplusOpLoweringToTOSA
     if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSoftplusOp: Rank required to create splatted const");
+    }
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     auto one = tosaBuilder.getSplattedConst(
-        1.0, outputType.getElementType(), outputType.getShape());
+        1.0, outputType.getElementType(), outputType.getRank());
     auto expOp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, expOp);
     auto expPlusOne = tosaBuilder.binaryOp<mlir::tosa::AddOp>(expOp, one);
     auto logOp = tosaBuilder.unaryOp<mlir::tosa::LogOp>(expPlusOne);
 
     rewriter.replaceOp(op, {logOp});
@@ -585,15 +608,19 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern<ONNXSeluOp> {
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSeluOp: Rank required to create splatted const");
+    }
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value gamma =
         tosaBuilder.getSplattedConst(adaptor.getGamma().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
     Value expTimesAlpha = tosaBuilder.mul(exp, alpha);
@@ -621,15 +648,19 @@ class ONNXThresholdedReluOpLoweringToTOSA
             rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXThresholdedReluOp: Rank required to create splatted const");
+    }
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     auto alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     auto zero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     auto greater = tosaBuilder.greater(input, alpha);
     auto select = tosaBuilder.select(greater, input, zero);
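Note: the guard added to HardSigmoid, Softplus, Selu, and ThresholdedRelu is what makes the rank-based getSplattedConst safe: a <1x1x...x1> splat can only be constructed once the result rank is known, so unranked results now bail out of the pattern and leave the ONNX op in place (the new *_unranked lit tests below check exactly that). The recurring shape of the guard, in isolation (names illustrative):

    // Bail out of the conversion pattern when the rank is unknown.
    if (!outputType.hasRank())
      return rewriter.notifyMatchFailure(
          op, "rank required to create splatted const");
    Value zero = tosaBuilder.getSplattedConst(
        0.0, outputType.getElementType(), outputType.getRank());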
diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
index dbc8b30411..1d1182d03b 100644
--- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -49,6 +49,10 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     FloatAttr beta = adaptor.getBetaAttr();
     auto AType = mlir::cast<TensorType>(A.getType());
     auto BType = mlir::cast<TensorType>(B.getType());
+    if (!AType.hasRank() || !BType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "Lowering Gemm to MatMul requires ranked A and B.");
+    }
     auto shapeA = AType.getShape();
     auto shapeB = BType.getShape();
     auto resultType = mlir::cast<TensorType>(
@@ -103,7 +107,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (alpha && alpha.getValueAsDouble() != 1.) {
       Value splattedConstAlpha = tosaBuilder.getSplattedConst(
           static_cast<float>(alpha.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       alphaMulResult = tosaBuilder.mul(splattedConstAlpha, A, 0);
     }
 
@@ -112,7 +116,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (beta && isCPresent && beta.getValueAsDouble() != 1.) {
       Value splattedConstBeta = tosaBuilder.getSplattedConst(
           static_cast<float>(beta.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       betaMulResult = tosaBuilder.mul(splattedConstBeta, C, 0);
     }
diff --git a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
index bb567b771d..72ce004df9 100644
--- a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
@@ -179,7 +179,7 @@ LogicalResult reduceMeanLowering(ONNXReduceMeanOp op,
   TosaBuilder tosaBuilder(rewriter, op->getLoc());
   Value divConst = tosaBuilder.getSplattedConst(
-      divScale, outputType.getElementType(), outputType.getShape());
+      divScale, outputType.getElementType(), outputType.getRank());
   auto output = tosaBuilder.mul(val, divConst);
 
   if (!output) {
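Note: reduceMeanLowering expresses the mean as a sum-reduction scaled by the reciprocal of the reduced element count, so the splatted scale factor only needs the output rank. A hedged sketch of the arithmetic (variable names illustrative; numReducedElements assumed to be the product of the reduced dimension sizes):

    // mean(x, axes) == reduce_sum(x, axes) * (1.0f / N)
    float divScale = 1.0f / static_cast<float>(numReducedElements);
    Value divConst = tosaBuilder.getSplattedConst(
        divScale, outputType.getElementType(), outputType.getRank());
    Value mean = tosaBuilder.mul(summed, divConst); // broadcasts over <1x..x1>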
diff --git a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
index 73ff2f898c..80f56bfb55 100644
--- a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
+++ b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
@@ -56,7 +56,7 @@ LogicalResult handleIncludePadAttr(
   Value padding = tosa::buildOnnxToTosaPaddingConstOp(
       rewriter, pads, loc, {0, 0, 0, 0}, {});
   auto constTosaTensor =
-      tosaBuilder.getSplattedConst(0.0, inputType.getElementType());
+      tosaBuilder.getSplattedConst(0.0, inputType.getElementType(), 0);
 
   auto padOp = tosa::CreateOpAndInfer<mlir::tosa::PadOp>(rewriter, loc,
       mlir::RankedTensorType::get(
diff --git a/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
index 396637d4cc..3cd31c4ebf 100644
--- a/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
+++ b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
@@ -29,6 +29,11 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
       OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
 
     auto outType = getTypeConverter()->convertType(op.getResult().getType());
+    if (!cast<ShapedType>(outType).hasRank()) {
+      return rewriter.notifyMatchFailure(op,
+          "ONNXBatchNormalizationInferenceModeOp to "
+          "TOSA requires a ranked result type");
+    }
    auto outTensorType = cast<RankedTensorType>(outType);
 
     // The layout of the output is N x C x D1 x D2 … Dn. For batch
@@ -60,7 +65,7 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
     // epsilon's shape: constant -> {1, 1, 1, ...}
     newShape[1] = 1;
     auto eps = tosaBuilder.getSplattedConst(op.getEpsilon().convertToFloat(),
-        outTensorType.getElementType(), newShape);
+        outTensorType.getElementType(), newShape.size());
 
     // output = (input - mean) * scale * rsqrt(var + eps) + bias
     auto op1SubInputMean =
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
index 1d87cc3dca..fd4c29393f 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
@@ -418,7 +418,7 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,
 
   if (!input_is_qtype) {
     Value div_const = tosaBuilder.getSplattedConst(
-        div_scale, output_type.getElementType(), output_type.getShape());
+        div_scale, output_type.getElementType(), output_type.getRank());
     return tosaBuilder.mul(val.value(), div_const);
   }
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
index 3a5177234c..24ca82fcac 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -56,9 +56,9 @@ llvm::SmallVector<int64_t> createInt64VectorFromIndexExpr(
 }
 
 mlir::RankedTensorType reduceAxisToOne(
-    llvm::ArrayRef<int64_t> shape, Type elementType, Attribute encoding) {
+    int64_t rank, Type elementType, Attribute encoding) {
   return mlir::RankedTensorType::get(
-      llvm::SmallVector<int64_t>(shape.size(), 1), elementType, encoding);
+      llvm::SmallVector<int64_t>(rank, 1), elementType, encoding);
 }
 
 mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val) {
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
index 8f0896c9c3..5c99af6d5d 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -39,9 +39,9 @@ int64_t convertNegativeAxis(int64_t axis, int64_t inputRank);
 llvm::SmallVector<int64_t> createInt64VectorFromIndexExpr(
     llvm::ArrayRef<IndexExpr> indexVector);
 
-// Create a RankedTensorType with shape and all elements being 1
-mlir::RankedTensorType reduceAxisToOne(llvm::ArrayRef<int64_t> shape,
-    mlir::Type elementType, mlir::Attribute encoding = {});
+// Create a RankedTensorType with the given rank and all dims being 1
+mlir::RankedTensorType reduceAxisToOne(
+    int64_t rank, mlir::Type elementType, mlir::Attribute encoding = {});
 
 // Returns the value TOSA ConstOp
 template <typename T>
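Note: reduceAxisToOne is the single place the splat type is built, so switching it to a rank parameter propagates the new convention everywhere. For instance, reduceAxisToOne(3, f32Type) yields tensor<1x1x1xf32>, exactly what reduceAxisToOne({13, 21, 3}, f32Type) used to yield, but it also works when the operand's shape is ?x?x?. Small usage sketch (illustrative):

    // Rank 3 produces the all-ones shape {1, 1, 1}.
    mlir::RankedTensorType splatTy =
        tosa::reduceAxisToOne(/*rank=*/3, rewriter.getF32Type());
    // splatTy broadcasts against any rank-3 tensor, static or dynamic.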
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
index c34525872a..e9e74e838e 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
@@ -85,12 +85,12 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
     // Create an 1x..x1 constant containing the size of the gathered dimension.
     auto dimSize = tosaBuilder.getSplattedConst(
-        inputShape[axis], indicesType.getElementType(), indicesType.getShape());
+        inputShape[axis], indicesType.getElementType(), indicesType.getRank());
     auto indicesPlusDimSize =
         tosaBuilder.binaryOp<mlir::tosa::AddOp>(indices, dimSize);
 
     auto zero = tosaBuilder.getSplattedConst(
-        (int64_t)0, indicesType.getElementType(), indicesType.getShape());
+        (int64_t)0, indicesType.getElementType(), indicesType.getRank());
     auto indicesPositive = tosaBuilder.greaterEqual(indices, zero);
 
     auto newIndices =
diff --git a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
index c9fc9b6ea1..058cc43bad 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
@@ -40,6 +40,17 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
     Value data = adaptor.getData();
     Value pads = adaptor.getPads();
     Value constValue = adaptor.getConstantValue();
+
+    auto dataType = dyn_cast<ShapedType>(data.getType());
+    if (!dataType || !dataType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "input type has no static shape");
+    }
+
+    auto elementDtype = dataType.getElementType();
+    if (!isa<FloatType>(elementDtype) && !isTOSAInt(elementDtype)) {
+      return rewriter.notifyMatchFailure(op, "unsupported type");
+    }
+
     if (!adaptor.getAxes().getDefiningOp<ONNXNoneOp>()) {
       return rewriter.notifyMatchFailure(op, "only default axes are supported");
     }
@@ -78,27 +89,49 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
     mlir::Type resultType =
         getTypeConverter()->convertType(op.getResult().getType());
 
-    float valueFloat = 0.0F;
     if (!isa<NoneType>(constValue.getType())) {
       auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
-      auto valueIt = valueAttr.getValues<FloatAttr>().begin();
-      // Need float for F32 Type
-      float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
-      TosaBuilder tosaBuilder(rewriter, loc);
-      Value constTosaTensor =
-          tosaBuilder.getSplattedConst(valueFloat, valueAttr.getElementType());
+      Value constTosaTensor;
+      if (isa<FloatType>(valueAttr.getElementType())) {
+        auto valueIt = valueAttr.getValues<FloatAttr>().begin();
+        const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
+        constTosaTensor = tosaBuilder.getSplattedConst(
+            valueFloat, valueAttr.getElementType(), 0);
+      } else {
+        assert(isTOSAInt(elementDtype) && "Already validated");
+        auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
+        auto valueAsAPInt = cast<IntegerAttr>(*valueIt).getValue();
+        auto asIntegerTy = cast<IntegerType>(valueAttr.getElementType());
+        if (asIntegerTy.isUnsigned()) {
+          constTosaTensor = tosaBuilder.getSplattedConst(
+              valueAsAPInt.getZExtValue(), asIntegerTy, 0);
+        } else {
+          constTosaTensor = tosaBuilder.getSplattedConst(
+              valueAsAPInt.getSExtValue(), asIntegerTy, 0);
+        }
+      }
 
       rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
           op, resultType, data, padsList1, constTosaTensor);
-    } else {
-      auto constType = RankedTensorType::get({}, rewriter.getF32Type());
-      auto constAttr = DenseElementsAttr::get(constType, valueFloat);
-      Value constTosaTensor = rewriter.create<mlir::tosa::ConstOp>(
-          op->getLoc(), constType, constAttr);
 
-      rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
-          op, resultType, data, padsList1, constTosaTensor);
+    } else {
+      auto constType = RankedTensorType::get({}, elementDtype);
+
+      DenseElementsAttr constAttr;
+      if (isa<FloatType>(elementDtype)) {
+        constAttr = DenseElementsAttr::get(constType, 0.0F);
+      } else {
+        assert(isTOSAInt(elementDtype) && "Already validated");
+        auto tyAsInt = cast<IntegerType>(elementDtype);
+        constAttr = DenseElementsAttr::get(constType,
+            llvm::APInt(tyAsInt.getWidth(), 0, tyAsInt.getSignedness()));
+      }
+
+      rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(op, resultType, data,
+          padsList1,
+          rewriter.create<mlir::tosa::ConstOp>(
+              op->getLoc(), constType, constAttr));
     }
 
     return success();
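Note: when collapsing the scalar pad value's APInt to a C++ integer, the lowering must respect the element type's signedness, hence the getZExtValue/getSExtValue split above. A standalone illustration of why that matters (values illustrative):

    // The same 8-bit pattern 0xFF reads differently depending on signedness.
    llvm::APInt v(/*numBits=*/8, /*val=*/0xFF);
    uint64_t asUnsigned = v.getZExtValue(); // 255
    int64_t asSigned = v.getSExtValue();    // -1
    // Picking the wrong accessor would splat the wrong pad constant.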
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
index d1969073c1..eb1d78926d 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
@@ -48,13 +48,13 @@ class ONNXShrinkOpLoweringToTOSA : public OpConversionPattern<ONNXShrinkOp> {
     const float lambdAsFloat = lambd.getValue().convertToFloat();
     const float biasAsFloat = bias.getValue().convertToFloat();
     auto lambdConstOp = tosaBuilder.getSplattedConst(lambdAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
     auto negatedLambdConstOp = tosaBuilder.getSplattedConst(-lambdAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
     auto biasConstOp = tosaBuilder.getSplattedConst(biasAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
-    auto zeroConstOp = tosaBuilder.getSplattedConst(0,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
+    auto zeroConstOp = tosaBuilder.getSplattedConst(
+        0, inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
 
     // Formula to be implemented:
     // { x < -lambd, then y = x + bias
diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp
index 22a5b7a179..59b1b96f58 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Dialect/ONNX/DialectBuilder.hpp"
 #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -451,15 +452,14 @@ Value replaceSequenceAt(
 }
 
 bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
-#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
+  if (onnx_mlir::disableConvTransposeDecomposeOption) {
+    // Disable the ONNXConvTransposeOp decomposition patterns.
+    return false;
+  }
   ONNXConvTransposeOp op =
       mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
   return hasShapeAndRank(convTransposeResult) &&
          hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
-#else
-  // Disable the ONNXConvTransposeOp decomposition patterns.
-  return false;
-#endif
 }
 
 // Split on the specified axis. The length of each output is one.
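Note: with the #ifdef gone, the ConvTranspose decomposition is always compiled in and is skipped per invocation when the flag is set; this is the path the new lit test below exercises via its RUN line, e.g. (with model.mlir standing in for the input file):

    onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose model.mlir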
diff --git a/src/Dialect/ONNX/Transforms/Recompose.cpp b/src/Dialect/ONNX/Transforms/Recompose.cpp
index e1b86eba28..5b86109d1f 100644
--- a/src/Dialect/ONNX/Transforms/Recompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -244,7 +244,9 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
         mlir::dyn_cast<ONNXConstantOp>(epsilon.getDefiningOp());
     if (!epsilonOp)
       return reportFailure("RMS epsilon needs to be a constant");
-    epsilonAttr = epsilonOp.getValueFloatAttr();
+    const auto epsilonValue = getScalarValue<double>(epsilonOp);
+    epsilonAttr =
+        FloatAttr::get(Float32Type::get(epsilonOp->getContext()), epsilonValue);
     // Check axes.
     if (!hasShapeAndRank(dd))
       return reportFailure("RMS need rank and shape for input dd");
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
index 3247d16c70..2e46c13420 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -203,6 +203,16 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
 
 // -----
 
+func.func @test_mul_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+  %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+  "func.return"(%0) : (tensor<13x?x?xf32>) -> ()
+// CHECK-LABEL: func @test_mul_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+}
+
+// -----
+
 func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
   %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32>
   "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
@@ -235,6 +245,29 @@ func.func @test_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
 
 // -----
 
+func.func @test_div_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<13x?x?xi32>) -> tensor<13x?x?xi32>
+  "func.return"(%0) : (tensor<13x?x?xi32>) -> ()
+// CHECK-LABEL: func @test_div_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xi32>, [[PARAM_1_:%.+]]: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.int_div [[PARAM_0_]], [[PARAM_1_]] : (tensor<?x?x?xi32>, tensor<13x?x?xi32>) -> tensor<13x?x?xi32>
+}
+
+// -----
+
+func.func @test_div_dynamic_float(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+  "func.return"(%0) : (tensor<13x?x?xf32>) -> ()
+// CHECK-LABEL: func.func @test_div_dynamic_float
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+// CHECK: return [[VAR_1_]] : tensor<13x?x?xf32>
+// CHECK: }
+}
+
+// -----
+
 func.func @test_div_unsigned(%arg0: tensor<13x21x1xui8>, %arg1: tensor<13x21x1xui8>) -> tensor<13x21x1xui8> {
   %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xui8>, tensor<13x21x1xui8>) -> tensor<13x21x1xui8>
   "func.return"(%0) : (tensor<13x21x1xui8>) -> ()
@@ -347,6 +380,37 @@ func.func @test_selu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 
 // -----
 
+func.func @test_selu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_selu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Selu"([[PARAM_0_]]) {alpha = 1.500000e+00 : f32, gamma = 2.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_selu_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_selu_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.500000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_2_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xi1>
+// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_8_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
 func.func @test_softplus(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = "onnx.Softplus"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   func.return %0 : tensor<13x21x3xf32>
@@ -358,6 +422,32 @@ func.func @test_softplus(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 // CHECK: return [[VAR_3_]] : tensor<13x21x3xf32>
 }
 
+// -----
+
+func.func @test_softplus_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.Softplus"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_softplus_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_1_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_3_:%.+]] = tosa.log [[VAR_2_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_3_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_softplus_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softplus"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_softplus_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Softplus"([[PARAM_0_]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+}
+
+
 // -----
 
 func.func @test_thresholdedrelu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
@@ -386,6 +476,31 @@ func.func @test_thresholdedrelu_default_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 
 // -----
 
+func.func @test_thresholded_relu_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_thresholded_relu_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xi1>
+// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], [[PARAM_0_]], [[VAR_1_]] : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_3_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_thresholded_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_thresholded_relu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.ThresholdedRelu"([[PARAM_0_]]) {alpha = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+}
+
+// -----
+
 func.func @test_sigmoid(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
   %0 = "onnx.Sigmoid"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
   "func.return"(%0) : (tensor<10x10xf32>) -> ()
@@ -732,6 +847,32 @@ func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
 
 // -----
 
+func.func @test_hardsigmoid_dynamic(%arg0: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+  %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  return %0 : tensor<?x?x?xf16>
+// CHECK-LABEL: func.func @test_hardsigmoid_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16>
+// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<?x?x?xf16>, tensor<1x1x1xf16>) -> tensor<?x?x?xf16>
+// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<?x?x?xf16>, tensor<1x1x1xf16>) -> tensor<?x?x?xf16>
+// CHECK: return [[VAR_4_]] : tensor<?x?x?xf16>
+}
+
+// -----
+
+func.func @test_hardsigmoid_unranked(%arg0: tensor<*xf16>) -> tensor<*xf16> {
+  %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16>
+  return %0 : tensor<*xf16>
+// CHECK-LABEL: func.func @test_hardsigmoid_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf16>) -> tensor<*xf16> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16>
+// CHECK: return [[VAR_0_]] : tensor<*xf16>
+}
+
+// -----
+
 func.func @test_elu_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
@@ -764,6 +905,26 @@ func.func @test_elu_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
 // CHECK: return [[VAR_7_]]
 }
 
+// -----
+
+func.func @test_elu_unranked(%arg0: tensor<*xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<*xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+// CHECK-LABEL: func.func @test_elu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<3xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<*xf32>) -> tensor<3xf32>
+// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xi1>
+// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<*xi1>, tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+// CHECK: return [[VAR_7_]] : tensor<3xf32>
+// CHECK: }
+}
+
+
 // -----
 
 func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir
index 5e07001a83..ea0b6d662c 100644
--- a/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir
@@ -1,55 +1,137 @@
 // RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s
 
-func.func @test_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
+func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
   %noval = "onnx.NoValue"() {value} : () -> none
   %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
   %1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32>
   return %2 : tensor<24x22x52x42xf32>
-// CHECK-LABEL: test_pad
+// CHECK-LABEL: test_pad_f32
 // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
 // CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4.500000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
 }
 
 // -----
-func.func @test_no_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
+func.func @test_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
   %noval = "onnx.NoValue"() {value} : () -> none
   %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
   %1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32>
   %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32>
   return %2 : tensor<20x16x44x32xf32>
-// CHECK-LABEL: test_no_pad
+// CHECK-LABEL: test_no_pad_f32
 // CHECK: return %arg0
 }
 
 // -----
-func.func @test_novalue_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x45x33xf32> {
+func.func @test_novalue_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x45x33xf32> {
   %0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
   %1 = "onnx.NoValue"() {value} : () -> none
   %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xf32>
   return %2 : tensor<20x16x45x33xf32>
-// CHECK-LABEL: test_novalue_pad
+// CHECK-LABEL: test_novalue_pad_f32
 // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
 // CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
 }
 
 // -----
-func.func @test_novalue_no_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
+func.func @test_novalue_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
   %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
   %1 = "onnx.NoValue"() {value} : () -> none
   %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xf32>
   return %2 : tensor<20x16x44x32xf32>
-// CHECK-LABEL: test_novalue_no_pad
+// CHECK-LABEL: test_novalue_no_pad_f32
 // CHECK: return %arg0
 }
 
 // -----
-func.func @test_no_const_pad(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<8xi64>, %arg2: tensor<1xf32>) -> tensor<20x16x44x32xf32> {
+func.func @test_no_const_pad_f32(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<8xi64>, %arg2: tensor<1xf32>) -> tensor<20x16x44x32xf32> {
   %noval = "onnx.NoValue"() {value} : () -> none
   %2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32>
   return %2 : tensor<20x16x44x32xf32>
-// CHECK-LABEL: test_no_const_pad
+// CHECK-LABEL: test_no_const_pad_f32
 // CHECK: "onnx.Pad"
 }
+
+// -----
+func.func @test_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<24x22x52x42xi64> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<24x22x52x42xi64>
+  return %2 : tensor<24x22x52x42xi64>
+// CHECK-LABEL: test_pad_i64
+// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
+// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor<i64>}> : () -> tensor<i64>
+// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
+}
+
+// -----
+func.func @test_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64>
+  return %2 : tensor<20x16x44x32xi64>
+// CHECK-LABEL: test_no_pad_i64
+// CHECK: return %arg0
+}
+
+// -----
+func.func @test_novalue_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x45x33xi64> {
+  %0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.NoValue"() {value} : () -> none
+  %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xi64>
+  return %2 : tensor<20x16x45x33xi64>
+// CHECK-LABEL: test_novalue_pad_i64
+// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
+// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+// CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
+}
+
+// -----
+func.func @test_novalue_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> {
+  %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.NoValue"() {value} : () -> none
+  %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xi64>
+  return %2 : tensor<20x16x44x32xi64>
+// CHECK-LABEL: test_novalue_no_pad_i64
+// CHECK: return %arg0
+}
+
+// -----
+func.func @test_no_const_pad_i64(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<8xi64>, %arg2: tensor<1xi64>) -> tensor<20x16x44x32xi64> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64>
+  return %2 : tensor<20x16x44x32xi64>
+// CHECK-LABEL: test_no_const_pad_i64
+// CHECK: "onnx.Pad"
+}
+
+// -----
+func.func @test_pad_ui32(%arg0: tensor<20x16x44x32xui32>) -> tensor<24x22x52x42xui32> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xui32>} : () -> tensor<1xui32>
+  %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xui32>, tensor<8xi64>, tensor<1xui32>, none) -> tensor<24x22x52x42xui32>
+  return %2 : tensor<24x22x52x42xui32>
+// CHECK-LABEL: test_pad_ui32
+// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
+// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor<ui32>}> : () -> tensor<ui32>
+// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
+}
+
+// -----
+func.func @test_pad_bf16(%arg0: tensor<20x16x44x32xbf16>) -> tensor<24x22x52x42xbf16> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %1 = "onnx.Constant"() {value = dense<[4.500000e+00]> : tensor<1xbf16>} : () -> tensor<1xbf16>
+  %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xbf16>, tensor<8xi64>, tensor<1xbf16>, none) -> tensor<24x22x52x42xbf16>
+  return %2 : tensor<24x22x52x42xbf16>
+// CHECK-LABEL: test_pad_bf16
+// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
+// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4.500000e+00> : tensor<bf16>}> : () -> tensor<bf16>
+// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
+}
+
diff --git a/test/mlir/lit.cfg.py b/test/mlir/lit.cfg.py
index 24d855c09e..bf1706a95f 100644
--- a/test/mlir/lit.cfg.py
+++ b/test/mlir/lit.cfg.py
@@ -51,6 +51,3 @@
 # execution based on the available targets
 for arch in config.targets_to_build.split():
     config.available_features.add(arch.lower())
-
-if config.decomp_onnx_convtranspose:
-    config.available_features.add("decomp_onnx_convtranspose")
diff --git a/test/mlir/lit.site.cfg.py.in b/test/mlir/lit.site.cfg.py.in
index 243a9e0408..54b4839867 100644
--- a/test/mlir/lit.site.cfg.py.in
+++ b/test/mlir/lit.site.cfg.py.in
@@ -10,7 +10,6 @@ config.onnx_mlir_obj_root = r"@ONNX_MLIR_BIN_ROOT@"
 
 config.enable_stablehlo = @ONNX_MLIR_STABLEHLO_ENABLED@
 config.enable_nnpa = 0x0@NNPA_LIT_ENABLED@
-config.decomp_onnx_convtranspose = @ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED@
 
 # Support substitution of the tools_dir with user parameters. This is
 # used when we can't determine the tool dir at configuration time.
diff --git a/test/mlir/onnx/onnx_decompose_convtranspose.mlir b/test/mlir/onnx/onnx_decompose_convtranspose.mlir
index c539f766cc..b88dd4713a 100644
--- a/test/mlir/onnx/onnx_decompose_convtranspose.mlir
+++ b/test/mlir/onnx/onnx_decompose_convtranspose.mlir
@@ -1,6 +1,5 @@
 // RUN: onnx-mlir-opt --shape-inference --decompose-onnx %s -split-input-file | FileCheck %s
-// REQUIRES: decomp_onnx_convtranspose
 
 // -----
 
diff --git a/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir b/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir
new file mode 100644
index 0000000000..b31c6b28a6
--- /dev/null
+++ b/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir
@@ -0,0 +1,104 @@
+// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose %s -split-input-file | FileCheck %s
+
+
+// -----
+
+// Test unit strides. Only convert weight tensor
+
+  func.func @test_convtrans_unitstrides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
+    onnx.Return %1 : tensor<1x2x5x5xf32>
+// CHECK-LABEL: func.func @test_convtrans_unitstrides(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x5xf32>
+// CHECK: }
+  }
+
+// -----
+
+// Test 1d input
+
+  func.func @test_convtrans1d_unitstrides(%arg0: tensor<1x1x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
+    onnx.Return %1 : tensor<1x2x5xf32>
+// CHECK-LABEL: func.func @test_convtrans1d_unitstrides(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5xf32>
+// CHECK: }
+  }
+
+// -----
+
+// Test 3d input
+
+  func.func @test_convtrans3d_unitstrides(%arg0: tensor<1x1x3x4x5xf32>, %arg1: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
+    onnx.Return %1 : tensor<1x2x5x6x7xf32>
+// CHECK-LABEL: func.func @test_convtrans3d_unitstrides(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x4x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x6x7xf32>
+// CHECK: }
+  }
+
+// -----
+
+// Test non unit strides. Added pads between elements in input data.
+
+  func.func @test_convtrans_strides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
+    onnx.Return %1 : tensor<1x2x7x3xf32>
+// CHECK-LABEL: func.func @test_convtrans_strides(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x7x3xf32>
+// CHECK: }
+  }
+
+// -----
+
+// Test output_padding. Additional pads are inserted after Conv op
+
+  func.func @test_convtrans_outputpadding(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
+    onnx.Return %1 : tensor<1x2x10x8xf32>
+// CHECK-LABEL: func.func @test_convtrans_outputpadding(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x10x8xf32>
+// CHECK: }
+  }
+
+// -----
+
+// Test for unknown dimension in spatial dimensions
+
+  func.func @test_convtranspose_unknown_spatial_dim(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+    %0 = "onnx.NoValue"() {value} : () -> none
+    %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, none) -> tensor<?x?x?x?xf32>
+    onnx.Return %1 : tensor<?x?x?x?xf32>
+// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, none) -> tensor<?x?x?x?xf32>
+// CHECK: onnx.Return %[[VAL_3]] : tensor<?x?x?x?xf32>
+// CHECK: }
+  }
diff --git a/test/mlir/onnx/onnx_recompose.mlir b/test/mlir/onnx/onnx_recompose.mlir
index 34f1e1b1bd..264d88f18f 100644
--- a/test/mlir/onnx/onnx_recompose.mlir
+++ b/test/mlir/onnx/onnx_recompose.mlir
@layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { %x = "onnx.Add"(%input, %bias) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -23,7 +23,7 @@ func.func @layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: t // CHECK-LABEL: func.func @layernorm_with_spurious_adds // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -33,7 +33,7 @@ func.func @layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: t // Layernorm without bias func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -47,7 +47,7 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x // CHECK-LABEL: func.func @layernorm_without_bias // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : 
tensor<1x384x768xf32> // CHECK: } } @@ -55,7 +55,7 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x // ----- func.func @layernorm_without_bias_first_reduce_unsuitable_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -71,7 +71,7 @@ func.func @layernorm_without_bias_first_reduce_unsuitable_axis(%x: tensor<1x384x // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> // CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -79,7 +79,7 @@ func.func @layernorm_without_bias_first_reduce_unsuitable_axis(%x: tensor<1x384x // ----- func.func @layernorm_without_bias_second_reduce_unsuitable_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -92,7 +92,7 @@ func.func @layernorm_without_bias_second_reduce_unsuitable_axis(%x: tensor<1x384 // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_without_bias_second_reduce_unsuitable_axis // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> // CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> // CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -108,7 +108,7 @@ func.func @layernorm_without_bias_second_reduce_unsuitable_axis(%x: tensor<1x384 // ----- func.func @layernorm_without_bias_v18(%x: 
tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %axis = onnx.Constant dense<-1> : tensor<1xi64> %mean = "onnx.ReduceMean"(%x, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> @@ -123,7 +123,7 @@ func.func @layernorm_without_bias_v18(%x: tensor<1x384x768xf32>, %scale: tensor< // CHECK-LABEL: func.func @layernorm_without_bias_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -131,7 +131,7 @@ func.func @layernorm_without_bias_v18(%x: tensor<1x384x768xf32>, %scale: tensor< // ----- func.func @layernorm_without_bias_v18_dynamic_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>, %axis: tensor) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMean"(%x, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -144,7 +144,7 @@ func.func @layernorm_without_bias_v18_dynamic_axis(%x: tensor<1x384x768xf32>, %s // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_without_bias_v18_dynamic_axis // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>, [[PARAM_3_:%.+]]: tensor) -> tensor<1x384x768xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[PARAM_3_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> // CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> // CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -160,7 +160,7 @@ func.func @layernorm_without_bias_v18_dynamic_axis(%x: tensor<1x384x768xf32>, %s // ----- func.func @layernorm_without_bias_first_reduce_unsuitable_axis_v18(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %axis1 = 
onnx.Constant dense<-2> : tensor<1xi64> %axis2 = onnx.Constant dense<-1> : tensor<1xi64> %mean = "onnx.ReduceMean"(%x, %axis1) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> @@ -179,7 +179,7 @@ func.func @layernorm_without_bias_first_reduce_unsuitable_axis_v18(%x: tensor<1x // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[VAR_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> // CHECK: [[VAR_3_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_3_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_3_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -187,7 +187,7 @@ func.func @layernorm_without_bias_first_reduce_unsuitable_axis_v18(%x: tensor<1x // ----- func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %axis1 = onnx.Constant dense<-1> : tensor<1xi64> %axis2 = onnx.Constant dense<-2> : tensor<1xi64> %mean = "onnx.ReduceMean"(%x, %axis1) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> @@ -202,7 +202,7 @@ func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18(%x: tensor<1 // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64> // CHECK: [[VAR_3_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[VAR_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> @@ -220,7 +220,7 @@ func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18(%x: tensor<1 // ----- func.func @layernorm_without_bias_v18_noop(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %none = "onnx.NoValue"() {value} : () -> none %mean = "onnx.ReduceMean"(%x, %none) {keepdims = 1 : si64, noop_with_empty_axes = 1: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x768xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -234,7 +234,7 @@ func.func @layernorm_without_bias_v18_noop(%x: tensor<1x384x768xf32>, %scale: te // mlir2FileCheck.py // CHECK-LABEL: 
func.func @layernorm_without_bias_v18_noop // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor // CHECK: [[VAR_1_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> // CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> // CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x768xf32> @@ -248,7 +248,7 @@ func.func @layernorm_without_bias_v18_noop(%x: tensor<1x384x768xf32>, %scale: te // ----- func.func @layernorm_without_bias_v18_reduce_all(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %none = "onnx.NoValue"() {value} : () -> none %mean = "onnx.ReduceMean"(%x, %none) {keepdims = 1 : si64, noop_with_empty_axes = 0: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> @@ -263,7 +263,7 @@ func.func @layernorm_without_bias_v18_reduce_all(%x: tensor<1x384x768xf32>, %sca // CHECK-LABEL: func.func @layernorm_without_bias_v18_reduce_all // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 0 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 0 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -273,7 +273,7 @@ func.func @layernorm_without_bias_v18_reduce_all(%x: tensor<1x384x768xf32>, %sca // Layernorm, add/mul switched func.func @layernorm_with_bias_switched(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -288,7 +288,7 @@ func.func @layernorm_with_bias_switched(%x: tensor<1x384x768xf32>, %scale: tenso // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_with_bias_switched // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], 
[[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -298,13 +298,13 @@ func.func @layernorm_with_bias_switched(%x: tensor<1x384x768xf32>, %scale: tenso // Recognize the bias and fold into LayerNorm. func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> { %0 = "onnx.NoValue"() {value} : () -> none - %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) + %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) %Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> return %Y : tensor<1x384x768xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_without_bias // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -313,7 +313,7 @@ func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<76 // Not a Layernorm as top sub has inputs switched func.func @not_a_layer_norm(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%mean, %x) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -333,7 +333,7 @@ func.func @not_a_layer_norm(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // ----- // Check alternative layer norm with reciprocal instead of div func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %x = "onnx.Add"(%input, %input) : 
(tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> @@ -351,7 +351,7 @@ func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: ten // CHECK-LABEL: func.func @layer_norm_with_reciprocal // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -361,7 +361,7 @@ func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: ten // Check alternative layer norm with reciprocal instead of div func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %one = onnx.Constant dense<1.0> : tensor %x = "onnx.Add"(%input, %input) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> @@ -380,7 +380,7 @@ func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: ten // CHECK-LABEL: func.func @layer_norm_with_div_by_one // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -390,7 +390,7 @@ func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: ten // Check 
alternative layer norm with reciprocal instead of div, fail because it is 2 / x instead of 1 / x func.func @not_a_layer_norm_with_div_by_two(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %one = onnx.Constant dense<2.0> : tensor %x = "onnx.Add"(%input, %input) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> @@ -416,7 +416,7 @@ func.func @not_a_layer_norm_with_div_by_two(%input: tensor<1x384x768xf32>, %scal // RMS Layer norm (sub switched) func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%mean, %x) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -433,7 +433,7 @@ func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> // CHECK: [[VAR_1_:%.+]] = "onnx.Sub"([[VAR_0_]], [[PARAM_0_]]) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -443,7 +443,7 @@ func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // RMS Layer norm func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %dd = "onnx.Mul"(%x, %x) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %var = "onnx.ReduceMeanV13"(%dd) {axes = [-1], keepdims = 1 : si64, onnx_node_name = "ReduceMean_42"} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %varEps = "onnx.Add"(%eps, %var) : (tensor, tensor<1x384x1xf32>) -> tensor<1x384x1xf32> @@ -456,7 +456,7 @@ func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // mlir2FileCheck.py // CHECK-LABEL: func.func @rms_layer_norm_v2 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = 
"onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } }