Skip to content

Commit

Permalink
feat(ONNXToTosa): legalization for onnx.Elu.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jan 4, 2024
1 parent 4db5ac4 commit f24a4a0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 5 deletions.
36 changes: 36 additions & 0 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ using namespace mlir;

namespace onnx_mlir {

// Shared helper for TOSA comparison lowerings.
// Emits the comparison op `T` (tosa::EqualOp, tosa::GreaterOp, or
// tosa::GreaterEqualOp) on lhs/rhs. TOSA comparisons always yield a boolean
// (i1) tensor; the shape is left unranked and filled in by shape inference.
//
// Fix: the original body hardcoded mlir::tosa::GreaterEqualOp instead of
// using the template parameter, so equal()/greater() also produced
// tosa.greater_equal. Use T so each wrapper emits its own op.
template <typename T>
Value compareOp(mlir::PatternRewriter &rewriter, mlir::Location loc,
    mlir::Value &lhs, mlir::Value &rhs) {
  return tosa::CreateOpAndInfer<T>(
      rewriter, loc, UnrankedTensorType::get(rewriter.getI1Type()), lhs, rhs);
}

template <typename T>
bool TosaBuilder::testNumberOfElementsMatch(
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
Expand Down Expand Up @@ -233,6 +240,15 @@ Value TosaBuilder::reciprocal(mlir::Value &input) {
rewriter(), loc(), newValueType, input);
}

// Emits tosa.exp on `input`. The result type starts as a fully dynamic
// ranked tensor (same rank and element type as the input); CreateOpAndInfer
// refines the actual dimensions via shape inference.
Value TosaBuilder::exp(mlir::Value &input) {
  auto shapedTy = input.getType().cast<ShapedType>();
  llvm::SmallVector<int64_t, 4> dynShape(
      shapedTy.getRank(), ShapedType::kDynamic);
  Type resultTy = RankedTensorType::get(dynShape, shapedTy.getElementType());
  return tosa::CreateOpAndInfer<mlir::tosa::ExpOp>(
      rewriter(), loc(), resultTy, input);
}

template <typename T>
Value TosaBuilder::binaryOp(mlir::Value &lhs, mlir::Value &rhs) {
if (needsRankBroadcast({lhs, rhs})) {
Expand All @@ -256,6 +272,26 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(
template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(
mlir::Value &lhs, mlir::Value &rhs);

// Emits tosa.equal via the shared compareOp helper; the result is a
// boolean (i1) tensor whose shape is inferred from the operands.
mlir::Value TosaBuilder::equal(mlir::Value &lhs, mlir::Value &rhs) {
  return compareOp<mlir::tosa::EqualOp>(rewriter(), loc(), lhs, rhs);
}

// Emits tosa.greater (lhs > rhs) via the shared compareOp helper; the
// result is a boolean (i1) tensor whose shape is inferred from the operands.
mlir::Value TosaBuilder::greater(mlir::Value &lhs, mlir::Value &rhs) {
  return compareOp<mlir::tosa::GreaterOp>(rewriter(), loc(), lhs, rhs);
}

// Emits tosa.greater_equal (lhs >= rhs) via the shared compareOp helper;
// the result is a boolean (i1) tensor whose shape is inferred from the
// operands.
mlir::Value TosaBuilder::greaterEqual(mlir::Value &lhs, mlir::Value &rhs) {
  return compareOp<mlir::tosa::GreaterEqualOp>(rewriter(), loc(), lhs, rhs);
}

// TOSA has no dedicated "less" op: lhs < rhs is emitted as rhs > lhs
// by swapping the operands into greater().
mlir::Value TosaBuilder::less(mlir::Value &lhs, mlir::Value &rhs) {
  return greater(rhs, lhs);
}

// TOSA has no dedicated "less_equal" op: lhs <= rhs is emitted as
// rhs >= lhs by swapping the operands into greaterEqual().
mlir::Value TosaBuilder::lessEqual(mlir::Value &lhs, mlir::Value &rhs) {
  return greaterEqual(rhs, lhs);
}

Value TosaBuilder::sqrt(mlir::Value &input) {
auto inputType = input.getType().cast<ShapedType>();
auto oneHalf = this->getSplattedConst(
Expand Down
47 changes: 42 additions & 5 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,15 @@ static LogicalResult LegalizeFloatingPointPrelu(Operation *op,
TosaBuilder tosaBuilder(rewriter, loc);
Value constZero = tosaBuilder.getSplattedConst(0.0, outputType.getShape());

auto greaterEqual =
tosa::CreateOpAndInfer<mlir::tosa::GreaterEqualOp>(rewriter, op->getLoc(),
UnrankedTensorType::get(rewriter.getI1Type()), input, constZero);

auto mul = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(rewriter, op->getLoc(),
outputType, input,
tosaBuilder.getSplattedConst(alpha, outputType.getShape()),
/*shift=*/0);

auto greaterEqual =
tosa::CreateOpAndInfer<mlir::tosa::GreaterEqualOp>(rewriter, op->getLoc(),
UnrankedTensorType::get(rewriter.getI1Type()), input, constZero);

tosa::CreateReplaceOpAndInfer<mlir::tosa::SelectOp>(
rewriter, op, outputType, greaterEqual, input, mul.getResult());

Expand Down Expand Up @@ -386,6 +386,43 @@ class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern<ONNXSqrtOp> {
}
};

// Lowers onnx.Elu to TOSA:
//   ELU(x) = x                      if x >= 0
//            alpha * (exp(x) - 1)   if x <  0
// realized as select(x >= 0, x, alpha * (exp(x) - 1)).
class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXEluOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto resultTy = op.getResult().getType().cast<TensorType>();
    // Only floating-point element types are supported by this lowering.
    if (failed(IsFloat::checkType(rewriter, resultTy.getElementType(), op))) {
      return failure();
    }

    Value x = op.getX();

    TosaBuilder tosaBuilder(rewriter, op->getLoc());

    // Splatted constants shaped like the result: 1.0, alpha, and 0.0.
    Value oneTensor = tosaBuilder.getSplattedConst(
        1.0, resultTy.getShape(), resultTy.getElementType());
    Value alphaTensor =
        tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
            resultTy.getShape(), resultTy.getElementType());
    Value zeroTensor = tosaBuilder.getSplattedConst(
        0.0, resultTy.getShape(), resultTy.getElementType());

    // Negative branch: alpha * (exp(x) - 1).
    Value expX = tosaBuilder.exp(x);
    Value expXMinusOne = tosaBuilder.binaryOp<mlir::tosa::SubOp>(expX, oneTensor);
    Value negBranch = tosaBuilder.mul(expXMinusOne, alphaTensor);
    // Predicate picking the identity branch where x >= 0.
    Value isNonNegative = tosaBuilder.greaterEqual(x, zeroTensor);

    tosa::CreateReplaceOpAndInfer<mlir::tosa::SelectOp>(
        rewriter, op, resultTy, isNonNegative, x, negBranch);
    return success();
  }
};

class ONNXHardSigmoidOpLoweringToTOSA
: public OpConversionPattern<ONNXHardSigmoidOp> {
public:
Expand Down Expand Up @@ -485,7 +522,7 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
patterns.insert<ONNXReluOpLoweringToTOSA, ONNXLeakyReluOpLoweringToTOSA,
ONNXMulOpLoweringToTosa, ONNXClipOpLoweringToTOSA,
ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA,
ONNXSqrtOpLoweringToTOSA>(typeConverter, ctx);
ONNXSqrtOpLoweringToTOSA, ONNXEluOpLoweringToTOSA>(typeConverter, ctx);

populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
patterns, typeConverter, ctx);
Expand Down
36 changes: 36 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,42 @@ func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}

// -----

// Verifies the full onnx.Elu decomposition on f32:
// select(x >= 0, x, alpha * (exp(x) - 1)) built from splatted constants.
func.func @test_elu_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func.func @test_elu_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.exp"([[PARAM_0_]]) : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_4_:%.+]] = "tosa.sub"([[VAR_3_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = "tosa.mul"([[VAR_4_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.greater_equal"([[PARAM_0_]], [[VAR_2_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK: [[VAR_7_:%.+]] = "tosa.select"([[VAR_6_]], [[PARAM_0_]], [[VAR_5_]]) : (tensor<3xi1>, tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAR_7_]]
}

// Verifies the onnx.Elu decomposition on f16 (alpha is rounded to the
// nearest f16 value, 1.666260e-01, by the splatted constant).
// Fix: dropped the stray `beta` attribute copy-pasted from the HardSigmoid
// test — ONNX Elu defines only `alpha`.
func.func @test_elu_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
%0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<3xf16>) -> tensor<3xf16>
return %0 : tensor<3xf16>
// CHECK-LABEL: func.func @test_elu_f16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.exp"([[PARAM_0_]]) : (tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_4_:%.+]] = "tosa.sub"([[VAR_3_]], [[VAR_0_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_5_:%.+]] = "tosa.mul"([[VAR_4_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.greater_equal"([[PARAM_0_]], [[VAR_2_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xi1>
// CHECK: [[VAR_7_:%.+]] = "tosa.select"([[VAR_6_]], [[PARAM_0_]], [[VAR_5_]]) : (tensor<3xi1>, tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_7_]]
}

// -----

func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
Expand Down

0 comments on commit f24a4a0

Please sign in to comment.