Skip to content

Commit

Permalink
feat: Add onnx.erf lowering to TOSA
Browse files Browse the repository at this point in the history
  • Loading branch information
roberteg16 committed Dec 15, 2023
1 parent f4eb920 commit 36eb5f4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ struct AbsIOSupportedTypes {
}
};

struct ErfIOSupportedTypes {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!mlir::isa<mlir::BFloat16Type, mlir::Float16Type, mlir::Float32Type>(
scalarType)) {
return rewriter.notifyMatchFailure(op,
"this operation only supports fp16, fp32 or bf16");
}
return success();
}
};

struct IsIntOrFloat {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
Expand Down Expand Up @@ -349,7 +361,9 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXSigmoidOp, mlir::tosa::SigmoidOp,
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXAbsOp, mlir::tosa::AbsOp,
AbsIOSupportedTypes, AbsIOSupportedTypes>>(typeConverter, ctx);
AbsIOSupportedTypes, AbsIOSupportedTypes>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXErfOp, mlir::tosa::ErfOp,
ErfIOSupportedTypes, ErfIOSupportedTypes>>(typeConverter, ctx);
}

void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
Expand Down
22 changes: 22 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,25 @@ func.func @test_abs_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16>
// CHECK-NEXT: }
}

// -----

func.func @test_erf_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Erf"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func @test_erf_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.erf"([[PARAM_0_]]) : (tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xf32>
// CHECK-NEXT: }
}

func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {
%0 = "onnx.Erf"(%arg0) : (tensor<3xbf16>) -> tensor<3xbf16>
return %0 : tensor<3xbf16>
// CHECK-LABEL: func @test_erf_bf16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.erf"([[PARAM_0_]]) : (tensor<3xbf16>) -> tensor<3xbf16>
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16>
// CHECK-NEXT: }
}

0 comments on commit 36eb5f4

Please sign in to comment.