forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ONNX to TOSA lowering for onnx.Tile
* Lowering can only be done if the input is ranked * The repetitions need to be constant
- Loading branch information
Showing
5 changed files
with
131 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===------------------------ Tile.cpp - Tile Op --------------------------===// | ||
// | ||
// Copyright (c) 2023 Advanced Micro Devices, Inc. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX TileOp to TOSA dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" | ||
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps.hpp" | ||
|
||
#include <mlir/Dialect/Tosa/IR/TosaOps.h> | ||
#include <mlir/IR/BuiltinAttributes.h> | ||
#include <mlir/IR/BuiltinTypeInterfaces.h> | ||
#include <mlir/Transforms/DialectConversion.h> | ||
|
||
#include <llvm/ADT/SmallVector.h> | ||
#include <llvm/Support/Casting.h> | ||
|
||
#include <cstdint> | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
class ONNXTileLoweringToTOSA : public OpConversionPattern<ONNXTileOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
using OpAdaptor = typename ONNXTileOp::Adaptor; | ||
LogicalResult matchAndRewrite(ONNXTileOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
auto inputType = adaptor.getInput().getType(); | ||
if (!onnx_mlir::isRankedShapedType(inputType)) { | ||
return rewriter.notifyMatchFailure( | ||
op, "input is not a ranked shaped tensor"); | ||
} | ||
|
||
int64_t inputRank = onnx_mlir::getRank(inputType); | ||
auto newResultElementType = cast<ShapedType>(inputType).getElementType(); | ||
Type newOutputType = RankedTensorType::get( | ||
llvm::SmallVector<int64_t>(inputRank, ShapedType::kDynamic), | ||
newResultElementType); | ||
|
||
// Create the attribute for the repetitions | ||
Value reps = adaptor.getRepeats(); | ||
auto repsConstant = | ||
dyn_cast_or_null<mlir::tosa::ConstOp>(reps.getDefiningOp()); | ||
if (!repsConstant) { | ||
return rewriter.notifyMatchFailure( | ||
op, "onnx.tile can only be lowered with constant repetitions"); | ||
} | ||
auto denseReps = repsConstant->getAttrOfType<DenseElementsAttr>("value"); | ||
llvm::SmallVector<int64_t> vals; | ||
for (auto val : denseReps.getValues<int64_t>()) { | ||
vals.push_back(val); | ||
} | ||
auto newReps = rewriter.getDenseI64ArrayAttr(vals); | ||
|
||
tosa::CreateReplaceOpAndInfer<mlir::tosa::TileOp>( | ||
rewriter, op, newOutputType, adaptor.getInput(), newReps); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXTileOpToTOSAPattern(ConversionTarget & /*target*/, | ||
RewritePatternSet &patterns, TypeConverter & /*typeConverter*/, | ||
MLIRContext *ctx) { | ||
patterns.insert<ONNXTileLoweringToTOSA>(ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s | ||
|
||
func.func @test_tile(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> { | ||
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> | ||
%tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x10x30x32xf32> | ||
"func.return"(%tile) : (tensor<5x10x30x32xf32>) -> () | ||
// CHECK-LABEL: test_tile | ||
// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @test_tile_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> { | ||
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> | ||
%tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x?x32xf32>, tensor<4xi64>) -> tensor<5x10x?x32xf32> | ||
"func.return"(%tile) : (tensor<5x10x?x32xf32>) -> () | ||
// CHECK-LABEL: test_tile_dynamic_shape | ||
// CHECK: tosa.tile{{.*}} <{multiples = array<i64: 1, 2, 30, 1>}> : (tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @test_tile_input_not_ranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> { | ||
%const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> | ||
%tile = "onnx.Tile"(%arg0, %const) : (tensor<*xf32>, tensor<4xi64>) -> tensor<*xf32> | ||
"func.return"(%tile) : (tensor<*xf32>) -> () | ||
// CHECK-LABEL: test_tile_input_not_ranked | ||
// CHECK-NOT: tosa.tile | ||
} | ||
|
||
// ----- | ||
|
||
func.func @test_tile_non_constant_reps(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> { | ||
%tile = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> | ||
"func.return"(%tile) : (tensor<*xf32>) -> () | ||
// CHECK-LABEL: test_tile_non_constant_reps | ||
// CHECK-NOT: tosa.tile | ||
} |