diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 612632e47a..476cd53a75 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -31,6 +31,20 @@ struct TOSADialectOp { using Op = mlir::tosa::NegateOp; }; +struct AbsIOSupportedTypes { + static LogicalResult checkType( + ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { + if (!mlir::isa( + scalarType) && + !scalarType.isSignlessInteger(/*width=*/32)) { + return rewriter.notifyMatchFailure(op, + "this operation only supports signless 32 integer or fp16, fp32" + " or bf16"); + } + return success(); + } +}; + struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { @@ -333,7 +347,9 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( ONNXElementwiseUnaryOpLoweringToTOSA, ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); + IsIntOrFloat, IsIntOrFloat>, + ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); } void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index fe10cddb1c..8a967039e0 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -297,3 +297,25 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens // CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.reshape"([[VAR_0_]]) <{new_shape = array}> : (tensor<1xf32>) -> tensor<1x1x1xf32> // CHECK-NEXT: [[VAR_2_:%.+]] = "tosa.mul"([[PARAM_0_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> } + +// ----- + +func.func @test_abs_i32(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +// CHECK-LABEL: func @test_abs_i32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.abs"([[PARAM_0_]]) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xi32> +// CHECK-NEXT: } +} + +func.func @test_abs_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xbf16>) -> tensor<3xbf16> + return %0 : tensor<3xbf16> +// CHECK-LABEL: func @test_abs_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.abs"([[PARAM_0_]]) : (tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16> +// CHECK-NEXT: } +}