Skip to content

Commit

Permalink
feat: Add onnx.abs lowering to TOSA
Browse files Browse the repository at this point in the history
  • Loading branch information
roberteg16 committed Dec 15, 2023
1 parent 98687c0 commit f4eb920
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ struct TOSADialectOp<ONNXNegOp> {
using Op = mlir::tosa::NegateOp;
};

struct AbsIOSupportedTypes {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!mlir::isa<mlir::BFloat16Type, mlir::Float16Type, mlir::Float32Type>(
scalarType) &&
!scalarType.isSignlessInteger(/*width=*/32)) {
return rewriter.notifyMatchFailure(op,
"this operation only supports signless 32 integer or fp16, fp32"
" or bf16");
}
return success();
}
};

struct IsIntOrFloat {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
Expand Down Expand Up @@ -333,7 +347,9 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXTanhOp, mlir::tosa::TanhOp,
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXSigmoidOp, mlir::tosa::SigmoidOp,
IsIntOrFloat, IsIntOrFloat>>(typeConverter, ctx);
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXAbsOp, mlir::tosa::AbsOp,
AbsIOSupportedTypes, AbsIOSupportedTypes>>(typeConverter, ctx);
}

void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
Expand Down
22 changes: 22 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,25 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.reshape"([[VAR_0_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: [[VAR_2_:%.+]] = "tosa.mul"([[PARAM_0_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_abs_i32(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
// CHECK-LABEL: func @test_abs_i32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.abs"([[PARAM_0_]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xi32>
// CHECK-NEXT: }
}

func.func @test_abs_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {
%0 = "onnx.Abs"(%arg0) : (tensor<3xbf16>) -> tensor<3xbf16>
return %0 : tensor<3xbf16>
// CHECK-LABEL: func @test_abs_bf16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.abs"([[PARAM_0_]]) : (tensor<3xbf16>) -> tensor<3xbf16>
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16>
// CHECK-NEXT: }
}

0 comments on commit f4eb920

Please sign in to comment.