Skip to content

Commit

Permalink
Merge branch 'feature/onnx-to-tosa' of https://github.com/xilinx/onnx…
Browse files Browse the repository at this point in the history
…-mlir into tiagot.onnx_to_tosa_xor
  • Loading branch information
ttjost committed Jan 5, 2024
2 parents 92dc63f + c9b84d3 commit f03513d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,11 @@ static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp,
IsIntOrFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXPowOp, mlir::tosa::PowOp,
IsFloat>>(typeConverter, ctx);
IsFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXMinOp, mlir::tosa::MinimumOp,
IsIntOrFloat>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXMaxOp, mlir::tosa::MaximumOp,
IsIntOrFloat>>(typeConverter, ctx);
}

static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
Expand Down
43 changes: 42 additions & 1 deletion test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -597,4 +597,45 @@ func.func @test_bitwise_xor_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<
// CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
// CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_xor"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>
}
}

// -----

func.func @test_min(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_min
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_min_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_min_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
}
// -----

func.func @test_max(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_max
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_max_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_max_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
}

0 comments on commit f03513d

Please sign in to comment.