From 300c98a396eeaa516e37ca616b129690f39148c4 Mon Sep 17 00:00:00 2001 From: Ehsan Nadjaran Toosi Date: Thu, 4 Jan 2024 16:35:47 +0000 Subject: [PATCH] feat: legalize onnx.max and onnx.min to tosa.maximum and tosa.minimum --- .../ONNXToTOSA/Math/Elementwise.cpp | 6 ++- .../onnx_to_tosa/Math/Elementwise.mlir | 43 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 24a6bfa67d..df4b2aa912 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -447,7 +447,11 @@ static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern( ONNXBinaryElementwiseOpLoweringToTOSA, ONNXBinaryElementwiseOpLoweringToTOSA>(typeConverter, ctx); + IsFloat>, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA>(typeConverter, ctx); } static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 51b859e386..61f87b7173 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -555,4 +555,45 @@ func.func @test_bitwise_or_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1 // CHECK: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array}> : (tensor<1xi64>) -> tensor<1x1x1xi64> // CHECK: [[VAR_1_:%.+]] = "tosa.bitwise_or"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64> // CHECK: return [[VAR_1_]] : tensor<13x21x1xi64> -} \ No newline at end of file +} + +// ----- + +func.func @test_min(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_min +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_min_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_min_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array}> : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.minimum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +} +// ----- + +func.func @test_max(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_max +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_max_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_max_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array}> : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.maximum"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +}