Skip to content

Commit

Permalink
feat(OnnxToTosa): legalization for onnx.BitwiseNot and onnx.Not.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jan 3, 2024
1 parent 0c6a798 commit 482dd65
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ void FrontendToTosaLoweringPass::runOnOperation() {
// conversion failures. Quantized types are not supported right now.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) -> std::optional<Type> {
if (isTOSASignedInt(type) || isTOSAFloat(type) || type.isa<NoneType>())
if (isTOSASignedInt(type) || isTOSAFloat(type) || type.isa<NoneType>() ||
isTOSABool(type))
return type;
return std::nullopt;
});
Expand Down
28 changes: 27 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,29 @@ struct IsIntOrFloat {
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!isTOSAFloat(scalarType) && !isTOSASignedInt(scalarType)) {
return rewriter.notifyMatchFailure(
op, "this operation only support signed integer or float types");
op, "this operation only supports signed integer or float types");
}
return success();
}
};

struct IsInt {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!isTOSASignedInt(scalarType)) {
return rewriter.notifyMatchFailure(
op, "this operation only supports float types");
}
return success();
}
};

struct IsBool {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!isTOSABool(scalarType)) {
return rewriter.notifyMatchFailure(
op, "this operation only supports bool type");
}
return success();
}
Expand Down Expand Up @@ -406,6 +428,10 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXSigmoidOp, mlir::tosa::SigmoidOp,
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXBitwiseNotOp,
mlir::tosa::BitwiseNotOp, IsInt, IsInt>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXNotOp, mlir::tosa::LogicalNotOp,
IsBool, IsBool>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXAbsOp, mlir::tosa::AbsOp,
AbsIOSupportedTypes, AbsIOSupportedTypes>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXErfOp, mlir::tosa::ErfOp,
Expand Down
5 changes: 5 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ std::optional<mlir::Value> convertReduceOpCommon(
// Check for valid TOSA types.
//===----------------------------------------------------------------------===//

inline bool isTOSABool(mlir::Type type) {
mlir::IntegerType intType = type.dyn_cast<mlir::IntegerType>();
return intType && intType.isSignless() && intType.getWidth() == 1;
}

inline bool isTOSASignedInt(mlir::Type type) {
mlir::IntegerType intType = type.dyn_cast<mlir::IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
Expand Down
24 changes: 24 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,30 @@ func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {

// -----

func.func @test_bitwise_not(%arg0 : tensor<10x10xi32>) -> tensor<10x10xi32> {
%0 = "onnx.BitwiseNot"(%arg0) : (tensor<10x10xi32>) -> tensor<10x10xi32>
"func.return"(%0) : (tensor<10x10xi32>) -> ()
// CHECK-LABEL: func @test_bitwise_not
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xi32>) -> tensor<10x10xi32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.bitwise_not"([[PARAM_0_]]) : (tensor<10x10xi32>) -> tensor<10x10xi32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xi32>
// CHECK-NEXT: }
}

// -----

func.func @test_not(%arg0 : tensor<10x10xi1>) -> tensor<10x10xi1> {
%0 = "onnx.Not"(%arg0) : (tensor<10x10xi1>) -> tensor<10x10xi1>
"func.return"(%0) : (tensor<10x10xi1>) -> ()
// CHECK-LABEL: func @test_not
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xi1>) -> tensor<10x10xi1> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.logical_not"([[PARAM_0_]]) : (tensor<10x10xi1>) -> tensor<10x10xi1>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xi1>
// CHECK-NEXT: }
}

// -----

// Default values: alpha = 0.2, beta = 0.5
func.func @test_hardsigmoid_default_values_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
Expand Down

0 comments on commit 482dd65

Please sign in to comment.