From 063fbb027dd68fac01b3bed480b4cf0870c8e460 Mon Sep 17 00:00:00 2001
From: "Rickert, Jonas"
Date: Mon, 3 Feb 2025 13:36:46 +0000
Subject: [PATCH] Make ONNX to TOSA lowering more resistant to dynamic shapes
 and unranked types.

- Add checks for types being ranked in various lowerings
- Refactor getSplattedConst to only require a rank and not a shape
- Manually annotate the return type for some tosa ops to prevent type
  mismatches
-- In some cases the result type of an op cannot be inferred just from its
   inputs; in these cases the annotation is always necessary
-- In some cases the TOSA shape inference does not infer types even when it
   would be possible; this is something I want to improve in another PR for
   MLIR.
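For illustration, a minimal sketch of the new getSplattedConst contract
(illustrative only, not part of the diff; `tosaBuilder` and `outputType`
stand for values the lowerings below already have in scope). The helper now
takes a rank instead of a shape and materializes a <1x1x...x1> constant,
which broadcasts against any operand of that rank, so dynamic dimensions in
the operand type no longer matter:

    // Hypothetical caller: outputType may contain dynamic dims, e.g.
    // tensor<13x?x?xf32>; only its rank (here 3) is needed.
    mlir::Value zero = tosaBuilder.getSplattedConst(
        0.0f, outputType.getElementType(), outputType.getRank());
    // zero has type tensor<1x1x1xf32> and broadcasts against the output.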
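Likewise, a sketch of the manual result-type annotation (again illustrative;
the real changes are in TosaBuilder below). When an operand may be unranked
there is no rank from which to build a placeholder result type, so the
builder keeps the operand type and lets TOSA shape inference refine the
result later:

    // If lhs is unranked, reuse its type as the result type; otherwise use
    // a ranked placeholder with all-dynamic dims that CreateOpAndInfer can
    // tighten via TOSA shape inference.
    auto lhsType = mlir::cast<mlir::ShapedType>(lhs.getType());
    mlir::Type newValueType =
        !lhsType.hasRank()
            ? lhsType
            : mlir::RankedTensorType::get(
                  llvm::SmallVector<int64_t>(
                      lhsType.getRank(), mlir::ShapedType::kDynamic),
                  lhsType.getElementType());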
---
 src/Conversion/ONNXToTOSA/DialectBuilder.cpp  | 58 ++++---
 src/Conversion/ONNXToTOSA/DialectBuilder.hpp  |  3 +-
 .../ONNXToTOSA/Math/Elementwise.cpp           | 56 +++++--
 src/Conversion/ONNXToTOSA/Math/Gemm.cpp       |  8 +-
 src/Conversion/ONNXToTOSA/Math/Reduce.cpp     |  2 +-
 src/Conversion/ONNXToTOSA/NN/AveragePool.cpp  |  2 +-
 src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp    |  7 +-
 .../ONNXToTOSA/ONNXToTOSACommon.cpp           |  2 +-
 .../ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp    |  4 +-
 .../ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp    |  6 +-
 src/Conversion/ONNXToTOSA/Tensor/Gather.cpp   |  4 +-
 .../ONNXToTOSA/Tensor/PaddingOp.cpp           |  6 +-
 src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp   | 10 +-
 .../onnx_to_tosa/Math/Elementwise.mlir        | 151 ++++++++++++++++++
 14 files changed, 256 insertions(+), 63 deletions(-)

diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
index d0caaf4161..a52c54053c 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -57,6 +57,12 @@ Value TosaBuilder::createConst(
 }
 
 bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) {
+  if (llvm::any_of(valueRange, [](const auto value) {
+        return !mlir::cast<ShapedType>(value.getType()).hasRank();
+      })) {
+    return false; // we have no way to determine the broadcast, so do not
+                  // attempt it
+  }
   int64_t firstRank = mlir::cast<ShapedType>(valueRange[0].getType()).getRank();
   for (Value operand : valueRange) {
     auto operandType = mlir::cast<ShapedType>(operand.getType());
@@ -129,9 +135,8 @@ Value TosaBuilder::getConst(ArrayRef<float> vec, ArrayRef<int64_t> shape) {
   return constOp;
 }
 
-Value TosaBuilder::getSplattedConst(
-    float val, Type dtype, llvm::ArrayRef<int64_t> shape) {
-  auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
+Value TosaBuilder::getSplattedConst(float val, Type dtype, int64_t rank) {
+  auto constType = tosa::reduceAxisToOne(rank, rewriter().getF32Type());
   auto constAttr = DenseElementsAttr::get(constType, val);
 
   auto constOp =
@@ -150,8 +155,7 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {
   auto valueType = mlir::cast<ShapedType>(value.getType());
   // get new value type
   Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(
-          valueType.getShape().size(), ShapedType::kDynamic),
+      llvm::SmallVector<int64_t>(perm.size(), ShapedType::kDynamic),
       valueType.getElementType());
   // create transpose for value
   Value newValue = tosa::CreateOpAndInfer<mlir::tosa::TransposeOp>(
@@ -195,9 +199,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
       rewriter(), loc(), newValueType, lhs, rhs, shift);
 }
@@ -215,9 +222,12 @@ Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
   }
 
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsElementType);
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsElementType);
   return tosa::CreateOpAndInfer<mlir::tosa::IntDivOp>(
       rewriter(), loc(), newValueType, lhs, rhs);
 }
@@ -230,9 +240,12 @@ Value TosaBuilder::binaryOp(Value &lhs, Value &rhs) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, lhs, rhs);
 }
 
@@ -246,11 +259,7 @@ template Value TosaBuilder::binaryOp(
 
 template <typename T>
 Value TosaBuilder::unaryOp(mlir::Value &input) {
-  auto inputType = cast<ShapedType>(input.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(inputType.getRank(), ShapedType::kDynamic),
-      inputType.getElementType());
-  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, input);
+  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), input.getType(), input);
 }
 
 template Value TosaBuilder::unaryOp(mlir::Value &input);
@@ -305,9 +314,12 @@ Value TosaBuilder::select(
     rhs = valueVec[2];
   }
   auto lhsType = cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::SelectOp>(
       rewriter(), loc(), newValueType, cond, lhs, rhs);
 }
@@ -328,7 +340,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType(
 Value TosaBuilder::sqrt(mlir::Value &input) {
   auto inputType = cast<ShapedType>(input.getType());
   auto oneHalf = this->getSplattedConst(
-      0.5, inputType.getElementType(), inputType.getShape());
+      0.5, inputType.getElementType(), inputType.getRank());
   return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
 }
 
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
index b97c912692..333f90c456 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -91,8 +91,7 @@ struct TosaBuilder : DialectBuilder {
   // The tensor will have the same rank as shape but all dimensions will
   // have size 1 (differs from tensorflow impl.)
   // If dtype is provided, it also cast the value to the appropriate dtype.
-  mlir::Value getSplattedConst(
-      float val, mlir::Type dtype, llvm::ArrayRef<int64_t> shape = {});
+  mlir::Value getSplattedConst(float val, mlir::Type dtype, int64_t rank);
 
   // Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
   // to `value`.
diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
index c82657e751..9e259174ca 100644
--- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -239,12 +239,12 @@ static LogicalResult legalizeFloatingPointPrelu(Operation *op,
   auto loc = op->getLoc();
   TosaBuilder tosaBuilder(rewriter, loc);
   Value constZero = tosaBuilder.getSplattedConst(
-      0.0, outputType.getElementType(), outputType.getShape());
+      0.0, outputType.getElementType(), outputType.getRank());
 
   auto mul = tosaBuilder.mul(input, alphaOrSlope);
   auto greaterEqual = tosaBuilder.greaterEqual(input, constZero);
   auto select = tosaBuilder.select(greaterEqual, input, mul);
-
+  copySingleResultType(op, select);
   rewriter.replaceOp(op, {select});
   return success();
 }
@@ -274,7 +274,7 @@ class ONNXLeakyReluOpLoweringToTOSA
     TosaBuilder tosaBuilder(rewriter, loc);
     return legalizeFloatingPointPrelu(op, rewriter, adaptor.getX(),
         tosaBuilder.getSplattedConst(
-            alpha, outputType.getElementType(), outputType.getShape()),
+            alpha, outputType.getElementType(), outputType.getRank()),
         outputType);
   }
 };
@@ -312,6 +312,7 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern {
     } else if constexpr (std::is_same_v) {
      res = tosaBuilder.less(input1, input2);
    }
+    copySingleResultType(op, res);
     rewriter.replaceOp(op, {res});
     return success();
   }
@@ -393,7 +394,7 @@ class ONNXCastOpLoweringToTOSA : public OpConversionPattern<ONNXCastOp> {
       // onnx.Cast and tosa.cast.
       if (resultTy.getElementType().getIntOrFloatBitWidth() != 1) {
         auto zero = tosaBuilder.getSplattedConst(
-            0.0f, inputTy.getElementType(), resultTy.getShape());
+            0.0f, inputTy.getElementType(), resultTy.getRank());
         auto positive = tosaBuilder.greaterEqual(input, zero);
 
         auto floor = tosaBuilder.unaryOp<mlir::tosa::FloorOp>(input);
@@ -421,6 +422,7 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
 
     if (isa<IntegerType>(resultElementType)) {
       Value divOp = tosaBuilder.intdiv(lhs, rhs);
+      copySingleResultType(op, divOp);
       rewriter.replaceOp(op, {divOp});
       return success();
     }
@@ -428,6 +430,7 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
     // tosa::ReciprocalOp and tosa::MulOp.
     Value reciprocalOp = tosaBuilder.unaryOp<mlir::tosa::ReciprocalOp>(rhs);
     Value mulOp = tosaBuilder.mul(lhs, reciprocalOp);
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
@@ -472,20 +475,21 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
 
     Value one = tosaBuilder.getSplattedConst(
-        1.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        1.0, resultTensorType.getElementType(), resultTensorType.getRank());
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            resultTensorType.getElementType(), resultTensorType.getShape());
+            resultTensorType.getElementType(), resultTensorType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        0.0, resultTensorType.getElementType(), resultTensorType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, exp);
     Value expMinusOne = tosaBuilder.binaryOp<mlir::tosa::SubOp>(exp, one);
     Value alphaTimesExpMinusOne = tosaBuilder.mul(expMinusOne, alpha);
     Value greaterEqual = tosaBuilder.greaterEqual(input, constZero);
     auto select =
         tosaBuilder.select(greaterEqual, input, alphaTimesExpMinusOne);
-
+    copySingleResultType(op, select);
     rewriter.replaceOp(op, {select});
     return success();
   }
@@ -516,11 +520,16 @@ class ONNXHardSigmoidOpLoweringToTOSA
     APFloat oneOverAlpha(alpha.getSemantics(), 1);
     oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);
 
+    if (!resultType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "HardSigmoid: Rank required to create splatted const");
+    }
+
     Value constBetaOverAlpha =
         tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
-            resultElementType, resultType.getShape());
+            resultElementType, resultType.getRank());
     Value constAlpha = tosaBuilder.getSplattedConst(
-        alpha.convertToDouble(), resultElementType, resultType.getShape());
+        alpha.convertToDouble(), resultElementType, resultType.getRank());
 
     auto addOp =
         tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
@@ -530,7 +539,7 @@ class ONNXHardSigmoidOpLoweringToTOSA
         rewriter.getF32FloatAttr(0),
         rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
     auto mulOp = tosaBuilder.mul(clampOp, constAlpha);
-
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
@@ -565,14 +574,19 @@ class ONNXSoftplusOpLoweringToTOSA
     if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSoftplusOp: Rank required to create splatted const");
+    }
 
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
 
     auto one = tosaBuilder.getSplattedConst(
-        1.0, outputType.getElementType(), outputType.getShape());
+        1.0, outputType.getElementType(), outputType.getRank());
     auto expOp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, expOp);
     auto expPlusOne = tosaBuilder.binaryOp<mlir::tosa::AddOp>(expOp, one);
     auto logOp = tosaBuilder.unaryOp<mlir::tosa::LogOp>(expPlusOne);
 
     rewriter.replaceOp(op, {logOp});
@@ -594,15 +608,19 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern<ONNXSeluOp> {
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSeluOp: Rank required to create splatted const");
+    }
 
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value gamma =
         tosaBuilder.getSplattedConst(adaptor.getGamma().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
     Value expTimesAlpha = tosaBuilder.mul(exp, alpha);
@@ -630,15 +648,19 @@ class ONNXThresholdedReluOpLoweringToTOSA
             rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXThresholdedReluOp: Rank required to create splatted const");
+    }
 
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
 
     auto alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     auto zero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     auto greater = tosaBuilder.greater(input, alpha);
     auto select = tosaBuilder.select(greater, input, zero);
diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
index dbc8b30411..1d1182d03b 100644
--- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -49,6 +49,10 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     FloatAttr beta = adaptor.getBetaAttr();
     auto AType = mlir::cast<ShapedType>(A.getType());
     auto BType = mlir::cast<ShapedType>(B.getType());
+    if (!AType.hasRank() || !BType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "Lowering Gemm to MatMul requires ranked A and B.");
+    }
     auto shapeA = AType.getShape();
     auto shapeB = BType.getShape();
    auto resultType = mlir::cast<ShapedType>(
@@ -103,7 +107,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (alpha && alpha.getValueAsDouble() != 1.) {
       Value splattedConstAlpha = tosaBuilder.getSplattedConst(
           static_cast<float>(alpha.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       alphaMulResult = tosaBuilder.mul(splattedConstAlpha, A, 0);
     }
 
@@ -112,7 +116,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (beta && isCPresent && beta.getValueAsDouble() != 1.) {
       Value splattedConstBeta = tosaBuilder.getSplattedConst(
           static_cast<float>(beta.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       betaMulResult = tosaBuilder.mul(splattedConstBeta, C, 0);
     }
 
diff --git a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
index bb567b771d..72ce004df9 100644
--- a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp
@@ -179,7 +179,7 @@ LogicalResult reduceMeanLowering(ONNXReduceMeanOp op,
   TosaBuilder tosaBuilder(rewriter, op->getLoc());
   Value divConst = tosaBuilder.getSplattedConst(
-      divScale, outputType.getElementType(), outputType.getShape());
+      divScale, outputType.getElementType(), outputType.getRank());
   auto output = tosaBuilder.mul(val, divConst);
 
   if (!output) {
diff --git a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
index 73ff2f898c..80f56bfb55 100644
--- a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
+++ b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
@@ -56,7 +56,7 @@ LogicalResult handleIncludePadAttr(
   Value padding = tosa::buildOnnxToTosaPaddingConstOp(
       rewriter, pads, loc, {0, 0, 0, 0}, {});
   auto constTosaTensor =
-      tosaBuilder.getSplattedConst(0.0, inputType.getElementType());
+      tosaBuilder.getSplattedConst(0.0, inputType.getElementType(), 0);
 
   auto padOp = tosa::CreateOpAndInfer<mlir::tosa::PadOp>(rewriter, loc,
       mlir::RankedTensorType::get(
diff --git a/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
index 396637d4cc..3cd31c4ebf 100644
--- a/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
+++ b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
@@ -29,6 +29,11 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
       OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
 
     auto outType = getTypeConverter()->convertType(op.getResult().getType());
+    if (!cast<ShapedType>(outType).hasRank()) {
+      return rewriter.notifyMatchFailure(op,
+          "ONNXBatchNormalizationInferenceModeOp to "
+          "TOSA requires a ranked result type");
+    }
     auto outTensorType = cast<ShapedType>(outType);
 
     // The layout of the output is N x C x D1 x D2 … Dn. For batch
@@ -60,7 +65,7 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
     // epsilon's shape: constant -> {1, 1, 1, ...}
     newShape[1] = 1;
     auto eps = tosaBuilder.getSplattedConst(op.getEpsilon().convertToFloat(),
-        outTensorType.getElementType(), newShape);
+        outTensorType.getElementType(), newShape.size());
 
     // output = (input - mean) * scale * rsqrt(var + eps) + bias
     auto op1SubInputMean =
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
index 1d87cc3dca..fd4c29393f 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
@@ -418,7 +418,7 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,
 
   if (!input_is_qtype) {
     Value div_const = tosaBuilder.getSplattedConst(
-        div_scale, output_type.getElementType(), output_type.getShape());
+        div_scale, output_type.getElementType(), output_type.getRank());
     return tosaBuilder.mul(val.value(), div_const);
   }
 
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
index 3a5177234c..24ca82fcac 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -56,9 +56,9 @@ llvm::SmallVector<int64_t> createInt64VectorFromIndexExpr(
 }
 
 mlir::RankedTensorType reduceAxisToOne(
-    llvm::ArrayRef<int64_t> shape, Type elementType, Attribute encoding) {
+    int64_t rank, Type elementType, Attribute encoding) {
   return mlir::RankedTensorType::get(
-      llvm::SmallVector<int64_t>(shape.size(), 1), elementType, encoding);
+      llvm::SmallVector<int64_t>(rank, 1), elementType, encoding);
 }
 
 mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val) {
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
index 8f0896c9c3..5c99af6d5d 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -39,9 +39,9 @@ int64_t convertNegativeAxis(int64_t axis, int64_t inputRank);
 llvm::SmallVector<int64_t> createInt64VectorFromIndexExpr(
     llvm::ArrayRef<IndexExpr> indexVector);
 
-// Create a RankedTensorType with shape and all elements being 1
-mlir::RankedTensorType reduceAxisToOne(llvm::ArrayRef<int64_t> shape,
-    mlir::Type elementType, mlir::Attribute encoding = {});
+// Create a RankedTensorType with the given rank and all dims being 1
+mlir::RankedTensorType reduceAxisToOne(
+    int64_t rank, mlir::Type elementType, mlir::Attribute encoding = {});
 
 // Returns the value TOSA ConstOp
 template <typename T>
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
index c34525872a..e9e74e838e 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp
@@ -85,12 +85,12 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
 
     // Create an 1x..x1 constant containing the size of the gathered dimension.
     auto dimSize = tosaBuilder.getSplattedConst(
-        inputShape[axis], indicesType.getElementType(), indicesType.getShape());
+        inputShape[axis], indicesType.getElementType(), indicesType.getRank());
     auto indicesPlusDimSize =
         tosaBuilder.binaryOp<mlir::tosa::AddOp>(indices, dimSize);
 
     auto zero = tosaBuilder.getSplattedConst(
-        (int64_t)0, indicesType.getElementType(), indicesType.getShape());
+        (int64_t)0, indicesType.getElementType(), indicesType.getRank());
     auto indicesPositive = tosaBuilder.greaterEqual(indices, zero);
 
     auto newIndices =
diff --git a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
index 41f5f5b02b..058cc43bad 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
@@ -98,7 +98,7 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
       auto valueIt = valueAttr.getValues<FloatAttr>().begin();
       const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
       constTosaTensor = tosaBuilder.getSplattedConst(
-          valueFloat, valueAttr.getElementType());
+          valueFloat, valueAttr.getElementType(), 0);
     } else {
       assert(isTOSAInt(elementDtype) && "Already validated");
       auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
@@ -106,10 +106,10 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
       auto asIntegerTy = cast<IntegerType>(valueAttr.getElementType());
       if (asIntegerTy.isUnsigned()) {
         constTosaTensor = tosaBuilder.getSplattedConst(
-            valueAsAPInt.getZExtValue(), asIntegerTy);
+            valueAsAPInt.getZExtValue(), asIntegerTy, 0);
       } else {
         constTosaTensor = tosaBuilder.getSplattedConst(
-            valueAsAPInt.getSExtValue(), asIntegerTy);
+            valueAsAPInt.getSExtValue(), asIntegerTy, 0);
       }
     }
     rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
index d1969073c1..eb1d78926d 100644
--- a/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
+++ b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp
@@ -48,13 +48,13 @@ class ONNXShrinkOpLoweringToTOSA : public OpConversionPattern<ONNXShrinkOp> {
     const float lambdAsFloat = lambd.getValue().convertToFloat();
     const float biasAsFloat = bias.getValue().convertToFloat();
     auto lambdConstOp = tosaBuilder.getSplattedConst(lambdAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
     auto negatedLambdConstOp = tosaBuilder.getSplattedConst(-lambdAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
     auto biasConstOp = tosaBuilder.getSplattedConst(biasAsFloat,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
-    auto zeroConstOp = tosaBuilder.getSplattedConst(0,
-        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getShape());
+        inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
+    auto zeroConstOp = tosaBuilder.getSplattedConst(
+        0, inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank());
 
     // Formula to be implemented:
     // { x < -lambd, then y = x + bias
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
index ea0a29bf43..2e46c13420 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -245,6 +245,29 @@ func.func @test_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x1xi32>) -> t
 
 // -----
 
+func.func @test_div_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x?x?xi32>, tensor<13x?x?xi32>) -> tensor<13x?x?xi32>
+  "func.return"(%0) : (tensor<13x?x?xi32>) -> ()
+// CHECK-LABEL: func @test_div_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xi32>, [[PARAM_1_:%.+]]: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.int_div [[PARAM_0_]], [[PARAM_1_]] : (tensor<?x?x?xi32>, tensor<13x?x?xi32>) -> tensor<13x?x?xi32>
+}
+
+// -----
+
+func.func @test_div_dynamic_float(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+  %0 = "onnx.Div"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+  "func.return"(%0) : (tensor<13x?x?xf32>) -> ()
+// CHECK-LABEL: func.func @test_div_dynamic_float
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
+// CHECK: return [[VAR_1_]] : tensor<13x?x?xf32>
+// CHECK: }
+}
+
+// -----
+
 func.func @test_div_unsigned(%arg0: tensor<13x21x1xui8>, %arg1: tensor<13x21x1xui8>) -> tensor<13x21x1xui8> {
   %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xui8>, tensor<13x21x1xui8>) -> tensor<13x21x1xui8>
   "func.return"(%0) : (tensor<13x21x1xui8>) -> ()
@@ -357,6 +380,37 @@ func.func @test_selu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 
 // -----
 
+func.func @test_selu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_selu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Selu"([[PARAM_0_]]) {alpha = 1.500000e+00 : f32, gamma = 2.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_selu_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_selu_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.500000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_2_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xi1>
+// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_8_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
 func.func @test_softplus(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = "onnx.Softplus"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   func.return %0 : tensor<13x21x3xf32>
@@ -368,6 +422,32 @@ func.func @test_softplus(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 // CHECK: return [[VAR_3_]] : tensor<13x21x3xf32>
 }
 
+// -----
+
+func.func @test_softplus_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.Softplus"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_softplus_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_1_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: [[VAR_3_:%.+]] = tosa.log [[VAR_2_]] : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_3_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_softplus_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softplus"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_softplus_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Softplus"([[PARAM_0_]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+}
+
+
 // -----
 
 func.func @test_thresholdedrelu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
@@ -396,6 +476,31 @@ func.func @test_thresholdedrelu_default_value(%arg0: tensor<13x21x3xf32>) -> ten
 
 // -----
 
+func.func @test_thresholded_relu_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  func.return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @test_thresholded_relu_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+// CHECK: [[VAR_2_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xi1>
+// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], [[PARAM_0_]], [[VAR_1_]] : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: return [[VAR_3_]] : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_thresholded_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+// CHECK-LABEL: func.func @test_thresholded_relu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.ThresholdedRelu"([[PARAM_0_]]) {alpha = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+}
+
+// -----
+
 func.func @test_sigmoid(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
   %0 = "onnx.Sigmoid"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
   "func.return"(%0) : (tensor<10x10xf32>) -> ()
@@ -742,6 +847,32 @@ func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
 
 // -----
 
+func.func @test_hardsigmoid_dynamic(%arg0: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+  %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  return %0 : tensor<?x?x?xf16>
+// CHECK-LABEL: func.func @test_hardsigmoid_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16>
+// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<?x?x?xf16>, tensor<1x1x1xf16>) -> tensor<?x?x?xf16>
+// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<?x?x?xf16>, tensor<1x1x1xf16>) -> tensor<?x?x?xf16>
+// CHECK: return [[VAR_4_]] : tensor<?x?x?xf16>
+}
+
+// -----
+
+func.func @test_hardsigmoid_unranked(%arg0: tensor<*xf16>) -> tensor<*xf16> {
+  %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16>
+  return %0 : tensor<*xf16>
+// CHECK-LABEL: func.func @test_hardsigmoid_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf16>) -> tensor<*xf16> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16>
+// CHECK: return [[VAR_0_]] : tensor<*xf16>
+}
+
+// -----
+
 func.func @test_elu_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
@@ -774,6 +905,26 @@ func.func @test_elu_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
 // CHECK: return [[VAR_7_]]
 }
 
+// -----
+
+func.func @test_elu_unranked(%arg0: tensor<*xf32>) -> tensor<3xf32> {
+  %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<*xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+// CHECK-LABEL: func.func @test_elu_unranked
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<3xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<*xf32>) -> tensor<3xf32>
+// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xi1>
+// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<*xi1>, tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+// CHECK: return [[VAR_7_]] : tensor<3xf32>
+// CHECK: }
+}
+
+
 // -----
 
 func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {