Skip to content

Commit

Permalink
Merge branch 'tiagot.onnx_to_tosa_pow' of https://github.com/xilinx/onnx-mlir into tiagot.legalize_sqrt

Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jan 4, 2024
2 parents 249d364 + 97921ba commit 71020ac
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,16 @@ class ONNXHardSigmoidOpLoweringToTOSA
static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp,
mlir::tosa::AddOp, IsIntOrFloat>,
patterns.insert<ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAndOp,
mlir::tosa::LogicalAndOp, IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseAndOp,
mlir::tosa::BitwiseAndOp, IsInt>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXOrOp, mlir::tosa::LogicalOrOp,
IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseOrOp,
mlir::tosa::BitwiseOrOp, IsInt>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp, mlir::tosa::AddOp,
IsIntOrFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp,
IsIntOrFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXPowOp, mlir::tosa::PowOp,
Expand Down
81 changes: 81 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,84 @@ func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}
// -----

// onnx.And on equal-shaped i1 tensors lowers directly to tosa.logical_and
// (no reshape needed — compare with @test_and_broadcast below).
func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func @test_and
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.logical_and"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
// CHECK-NEXT: return [[VAR_0_]] : tensor<13x21x1xi1>
}

// -----

// onnx.And with a rank-1 second operand: lowering inserts a tosa.reshape to
// rank-align the operand before tosa.logical_and (TOSA requires equal ranks;
// broadcasting then happens on the unit dims).
func.func @test_and_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func.func @test_and_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi1>) -> tensor<1x1x1xi1>
// CHECK: [[VAR_1_:%.+]] = "tosa.logical_and"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1>
}
// -----

// onnx.BitwiseAnd on equal-shaped integer tensors lowers directly to
// tosa.bitwise_and (no reshape needed).
func.func @test_bitwise_and(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func @test_bitwise_and
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.bitwise_and"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
// CHECK-NEXT: return [[VAR_0_]] : tensor<13x21x1xi64>
}
// -----

// onnx.BitwiseAnd with a rank-1 second operand: a tosa.reshape rank-aligns
// the operand before tosa.bitwise_and.
func.func @test_bitwise_and_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func.func @test_bitwise_and_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
// CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_and"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>
}
// -----

// onnx.Or on equal-shaped i1 tensors lowers directly to tosa.logical_or
// (no reshape needed).
func.func @test_or(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func @test_or
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.logical_or"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
// CHECK-NEXT: return [[VAR_0_]] : tensor<13x21x1xi1>
}
// -----

// onnx.Or with a rank-1 second operand: a tosa.reshape rank-aligns the
// operand before tosa.logical_or.
func.func @test_or_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func.func @test_or_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi1>) -> tensor<1x1x1xi1>
// CHECK: [[VAR_1_:%.+]] = "tosa.logical_or"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1>
}
// -----

// onnx.BitwiseOr on equal-shaped integer tensors lowers directly to
// tosa.bitwise_or (no reshape needed).
func.func @test_bitwise_or(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func @test_bitwise_or
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.bitwise_or"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
// CHECK-NEXT: return [[VAR_0_]] : tensor<13x21x1xi64>
}
// -----

// onnx.BitwiseOr with a rank-1 second operand: a tosa.reshape rank-aligns
// the operand before tosa.bitwise_or.
func.func @test_bitwise_or_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func.func @test_bitwise_or_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
// CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_or"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>
}

0 comments on commit 71020ac

Please sign in to comment.