Skip to content

Commit

Permalink
feat(ONNXTOTosa): legalization of ONNXXorOp and ONNXBitwiseXorOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jan 4, 2024
1 parent a320dec commit 835d343
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,10 @@ static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseOrOp,
mlir::tosa::BitwiseOrOp, IsInt>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXXorOp, mlir::tosa::LogicalXorOp,
IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseXorOp,
mlir::tosa::BitwiseXorOp, IsInt>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp, mlir::tosa::AddOp,
IsIntOrFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp,
Expand Down
43 changes: 43 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,47 @@ func.func @test_bitwise_or_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
// CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_or"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>

}

// -----

func.func @test_xor(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func @test_xor
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.logical_xor"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
}

// -----

func.func @test_xor_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
// CHECK-LABEL: func.func @test_xor_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi1>) -> tensor<1x1x1xi1>
// CHECK: [[VAR_1_:%.+]] = "tosa.logical_xor"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1>
}
// -----

func.func @test_bitwise_xor(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func @test_bitwise_xor
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.bitwise_xor"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64>
}
// -----

func.func @test_bitwise_xor_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> {
%0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64>
"func.return"(%0) : (tensor<13x21x1xi64>) -> ()
// CHECK-LABEL: func.func @test_bitwise_xor_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> {
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
// CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_xor"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>
}

0 comments on commit 835d343

Please sign in to comment.