Skip to content

Commit

Permalink
Add ONNX to TOSA lowering for onnx.Tile
Browse files Browse the repository at this point in the history
* Lowering can only be done if the input is ranked
* The repetitions need to be constant
  • Loading branch information
TinaAMD committed Jan 8, 2024
1 parent a145ceb commit 5e60687
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ add_onnx_mlir_library(OMONNXToTOSA
Tensor/Reshape.cpp
Tensor/Resize.cpp
Tensor/Slice.cpp
Tensor/Transpose.cpp
Tensor/Squeeze.cpp
Tensor/Tile.cpp
Tensor/Transpose.cpp
Flow/EntryPoint.cpp


Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target,
populateLoweringONNXPadOpToTOSAPattern(target, patterns, typeConverter, ctx);
populateLoweringONNXSliceOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXTransposeOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXSqueezeOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXTileOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXTransposeOpToTOSAPattern(
target, patterns, typeConverter, ctx);
// NN
populateLoweringONNXMaxPoolSingleOutOpToTOSAPattern(
target, patterns, typeConverter, ctx);
Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,12 @@ void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXTileOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
// 'Flow' directory methods:
void populateLoweringONNXEntryPointOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
Expand Down
83 changes: 83 additions & 0 deletions src/Conversion/ONNXToTOSA/Tensor/Tile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------------ Tile.cpp - Tile Op --------------------------===//
//
// Copyright (c) 2023 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX TileOp to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

#include <mlir/Dialect/Tosa/IR/TosaOps.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Casting.h>

#include <cstdint>

using namespace mlir;

namespace onnx_mlir {

namespace {

class ONNXTileLoweringToTOSA : public OpConversionPattern<ONNXTileOp> {
public:
using OpConversionPattern::OpConversionPattern;
using OpAdaptor = typename ONNXTileOp::Adaptor;
LogicalResult matchAndRewrite(ONNXTileOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto inputType = adaptor.getInput().getType();
if (!onnx_mlir::isRankedShapedType(inputType)) {
return rewriter.notifyMatchFailure(
op, "input is not a ranked shaped tensor");
}

int64_t inputRank = onnx_mlir::getRank(inputType);
auto newResultElementType = cast<ShapedType>(inputType).getElementType();
Type newOutputType = RankedTensorType::get(
llvm::SmallVector<int64_t>(inputRank, ShapedType::kDynamic),
newResultElementType);

// Create the attribute for the repetitions
Value reps = adaptor.getRepeats();
auto repsConstant =
dyn_cast_or_null<mlir::tosa::ConstOp>(reps.getDefiningOp());
if (!repsConstant) {
return rewriter.notifyMatchFailure(
op, "onnx.tile can only be lowered with constant repetitions");
}
auto denseReps = repsConstant->getAttrOfType<DenseElementsAttr>("value");
llvm::SmallVector<int64_t> vals;
for (auto val : denseReps.getValues<int64_t>()) {
vals.push_back(val);
}
auto newReps = rewriter.getDenseI64ArrayAttr(vals);

tosa::CreateReplaceOpAndInfer<mlir::tosa::TileOp>(
rewriter, op, newOutputType, adaptor.getInput(), newReps);
return success();
}
};

} // namespace

void populateLoweringONNXTileOpToTOSAPattern(ConversionTarget & /*target*/,
RewritePatternSet &patterns, TypeConverter & /*typeConverter*/,
MLIRContext *ctx) {
patterns.insert<ONNXTileLoweringToTOSA>(ctx);
}

} // namespace onnx_mlir
38 changes: 38 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s

func.func @test_tile(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> {
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
%tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x10x30x32xf32>
"func.return"(%tile) : (tensor<5x10x30x32xf32>) -> ()
// CHECK-LABEL: test_tile
// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32>
}

// -----

func.func @test_tile_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> {
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
%tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x?x32xf32>, tensor<4xi64>) -> tensor<5x10x?x32xf32>
"func.return"(%tile) : (tensor<5x10x?x32xf32>) -> ()
// CHECK-LABEL: test_tile_dynamic_shape
// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32>
}

// -----

func.func @test_tile_input_not_ranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64>
%tile = "onnx.Tile"(%arg0, %const) : (tensor<*xf32>, tensor<4xi64>) -> tensor<*xf32>
"func.return"(%tile) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_tile_input_not_ranked
// CHECK-NOT: tosa.tile
}

// -----

func.func @test_tile_non_constant_reps(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
%tile = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
"func.return"(%tile) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_tile_non_constant_reps
// CHECK-NOT: tosa.tile
}

0 comments on commit 5e60687

Please sign in to comment.