diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
index c61ea0d881..fba39620ac 100644
--- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -451,7 +451,11 @@ static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
       ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseXorOp,
           mlir::tosa::BitwiseXorOp, IsInt>,
       ONNXBinaryElementwiseOpLoweringToTOSA<ONNXPowOp, mlir::tosa::PowOp,
-          IsFloat>>(typeConverter, ctx);
+          IsFloat>,
+      ONNXBinaryElementwiseOpLoweringToTOSA<ONNXMinOp, mlir::tosa::MinimumOp,
+          IsIntOrFloat>,
+      ONNXBinaryElementwiseOpLoweringToTOSA<ONNXMaxOp, mlir::tosa::MaximumOp,
+          IsIntOrFloat>>(typeConverter, ctx);
 }
 
 static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
index f5d73d3f3e..3d7b16f578 100644
--- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -597,4 +597,45 @@ func.func @test_bitwise_xor_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> {
 // CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xi64>) -> tensor<1x1x1xi64>
 // CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_xor"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64>
 // CHECK: return [[VAR_1_]] : tensor<13x21x1xi64>
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @test_min(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
+  %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
+// CHECK-LABEL: func @test_min
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
+// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+}
+
+// -----
+
+func.func @test_min_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
+  %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
+  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
+// CHECK-LABEL: func @test_min_broadcast
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
+}
+// -----
+
+func.func @test_max(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
+  %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
+// CHECK-LABEL: func @test_max
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
+// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+}
+
+// -----
+
+func.func @test_max_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
+  %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
+  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
+// CHECK-LABEL: func @test_max_broadcast
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
+// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
+}
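
For a quick local check of the new lowering, a sketch like the following can be used. This is not part of the patch: the file name, function name, and exact pass invocation are assumptions (the convert-onnx-to-tosa pass name is inferred from the ONNXToTOSA conversion directory this change touches), so adjust them to the actual onnx-mlir-opt flags in your build.

// min_max_example.mlir (hypothetical file). Assumed invocation:
//   onnx-mlir-opt --convert-onnx-to-tosa min_max_example.mlir
// With the new patterns registered, onnx.Min/onnx.Max on float tensors
// should rewrite to tosa.minimum/tosa.maximum, matching the tests above.
func.func @min_max_example(%a: tensor<2x3xf32>, %b: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = "onnx.Min"(%a, %b) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  %1 = "onnx.Max"(%0, %b) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
}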