diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt
index 93efe4f27a4..da767224add 100644
--- a/src/Conversion/ONNXToTOSA/CMakeLists.txt
+++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt
@@ -26,8 +26,9 @@ add_onnx_mlir_library(OMONNXToTOSA
   Tensor/Reshape.cpp
   Tensor/Resize.cpp
   Tensor/Slice.cpp
-  Tensor/Transpose.cpp
   Tensor/Squeeze.cpp
+  Tensor/Tile.cpp
+  Tensor/Transpose.cpp
 
   Flow/EntryPoint.cpp
 
diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
index 2b28aebf782..b7588a60ca0 100644
--- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
+++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
@@ -54,10 +54,12 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target,
   populateLoweringONNXPadOpToTOSAPattern(target, patterns, typeConverter, ctx);
   populateLoweringONNXSliceOpToTOSAPattern(
       target, patterns, typeConverter, ctx);
-  populateLoweringONNXTransposeOpToTOSAPattern(
-      target, patterns, typeConverter, ctx);
   populateLoweringONNXSqueezeOpToTOSAPattern(
       target, patterns, typeConverter, ctx);
+  populateLoweringONNXTileOpToTOSAPattern(
+      target, patterns, typeConverter, ctx);
+  populateLoweringONNXTransposeOpToTOSAPattern(
+      target, patterns, typeConverter, ctx);
   // NN
   populateLoweringONNXMaxPoolSingleOutOpToTOSAPattern(
       target, patterns, typeConverter, ctx);
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
index a9bfa981277..19396c82268 100644
--- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
+++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
@@ -373,10 +373,12 @@ void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &,
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
 void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &,
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
-void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &,
-    mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
 void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &,
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
+void populateLoweringONNXTileOpToTOSAPattern(mlir::ConversionTarget &,
+    mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
+void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &,
+    mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
 // 'Flow' directory methods:
 void populateLoweringONNXEntryPointOpToTOSAPattern(mlir::ConversionTarget &,
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
diff --git a/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp b/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp
new file mode 100644
index 00000000000..6488bf0e9a6
--- /dev/null
+++ b/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp
@@ -0,0 +1,83 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===------------------------ Tile.cpp - Tile Op --------------------------===//
+//
+// Copyright (c) 2023 Advanced Micro Devices, Inc.
+//
+// =============================================================================
+//
+// This file lowers ONNX TileOp to TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
+#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+// NOTE(review): the original angle-bracket include targets were lost in
+// transit; reconstructed from the visible 4/2/1 grouping (MLIR / LLVM / std)
+// and from the names this file uses — confirm against the upstream commit.
+#include <mlir/Dialect/Tosa/IR/TosaOps.h>
+#include <mlir/IR/BuiltinAttributes.h>
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+#include <llvm/ADT/ArrayRef.h>
+#include <llvm/ADT/SmallVector.h>
+
+#include <cstdint>
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+namespace {
+
+// Lowers onnx.Tile to tosa.tile. Only constant repetition tensors are
+// supported, because tosa.tile encodes the repetitions as a static
+// DenseI64ArrayAttr ("multiples") rather than as an operand.
+class ONNXTileLoweringToTOSA : public OpConversionPattern<ONNXTileOp> {
+public:
+  using OpConversionPattern<ONNXTileOp>::OpConversionPattern;
+  using OpAdaptor = typename ONNXTileOp::Adaptor;
+  LogicalResult matchAndRewrite(ONNXTileOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+
+    auto inputType = adaptor.getInput().getType();
+    if (!onnx_mlir::isRankedShapedType(inputType)) {
+      return rewriter.notifyMatchFailure(
+          op, "input is not a ranked shaped tensor");
+    }
+
+    // Result rank equals input rank; exact dims are left dynamic here and
+    // refined by CreateReplaceOpAndInfer's shape inference below.
+    int64_t inputRank = onnx_mlir::getRank(inputType);
+    auto newResultElementType = cast<ShapedType>(inputType).getElementType();
+    Type newOutputType = RankedTensorType::get(
+        llvm::SmallVector<int64_t>(inputRank, ShapedType::kDynamic),
+        newResultElementType);
+
+    // Create the attribute for the repetitions
+    Value reps = adaptor.getRepeats();
+    auto repsConstant =
+        dyn_cast_or_null<mlir::ONNXConstantOp>(reps.getDefiningOp());
+    if (!repsConstant) {
+      return rewriter.notifyMatchFailure(
+          op, "onnx.tile can only be lowered with constant repetitions");
+    }
+    auto denseReps = repsConstant->getAttrOfType<DenseElementsAttr>("value");
+    llvm::SmallVector<int64_t> vals;
+    for (auto val : denseReps.getValues<int64_t>()) {
+      vals.push_back(val);
+    }
+    auto newReps = rewriter.getDenseI64ArrayAttr(vals);
+
+    tosa::CreateReplaceOpAndInfer<mlir::tosa::TileOp>(
+        rewriter, op, newOutputType, adaptor.getInput(), newReps);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateLoweringONNXTileOpToTOSAPattern(ConversionTarget & /*target*/,
+    RewritePatternSet &patterns, TypeConverter & /*typeConverter*/,
+    MLIRContext *ctx) {
+  patterns.insert<ONNXTileLoweringToTOSA>(ctx);
+}
+
+} // namespace onnx_mlir
diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir
new file mode 100644
index 00000000000..f0d33f93b1e
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir
@@ -0,0 +1,38 @@
+// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s
+
+func.func @test_tile(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> {
+  %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
+  %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x10x30x32xf32>
+  "func.return"(%tile) : (tensor<5x10x30x32xf32>) -> ()
+// CHECK-LABEL: test_tile
+// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32>
+}
+
+// -----
+
+func.func @test_tile_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> {
+  %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
+  %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x?x32xf32>, tensor<4xi64>) -> tensor<5x10x?x32xf32>
+  "func.return"(%tile) : (tensor<5x10x?x32xf32>) -> ()
+// CHECK-LABEL: test_tile_dynamic_shape
+// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32>
+}
+
+// -----
+
+func.func @test_tile_input_not_ranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+  %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
+  %tile = "onnx.Tile"(%arg0, %const) : (tensor<*xf32>, tensor<4xi64>) -> tensor<*xf32>
+  "func.return"(%tile) : (tensor<*xf32>) -> ()
+// CHECK-LABEL: test_tile_input_not_ranked
+// CHECK-NOT: tosa.tile
+}
+
+// -----
+
+func.func @test_tile_non_constant_reps(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
+  %tile = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
+  "func.return"(%tile) : (tensor<*xf32>) -> ()
+// CHECK-LABEL: test_tile_non_constant_reps
+// CHECK-NOT: tosa.tile
+}