Skip to content

Commit

Permalink
Merge pull request #28 from Xilinx/tiagot.legalize_sqrt
Browse files Browse the repository at this point in the history
feat(ONNXTOTosa): legalization sqrt.
  • Loading branch information
ttjost authored Jan 4, 2024
2 parents 97921ba + 71020ac commit a320dec
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(
template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(
mlir::Value &lhs, mlir::Value &rhs);

Value TosaBuilder::sqrt(mlir::Value &input) {
auto inputType = input.getType().cast<ShapedType>();
auto oneHalf = this->getSplattedConst(
0.5, inputType.getShape(), inputType.getElementType());
return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
}

static bool containsNonZero(llvm::SmallVectorImpl<int64_t> &values) {
return llvm::any_of(values, [](int64_t value) { return value != 0; });
}
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct TosaBuilder : DialectBuilder {
llvm::ArrayRef<int64_t> start);
mlir::Value reshape(mlir::Value &value, llvm::ArrayRef<int64_t> shape);
mlir::Value reciprocal(mlir::Value &input);
mlir::Value sqrt(mlir::Value &input);

/// When using window based ops like maxpool or conv2d, we sometimes have
/// unused values at the end of a spatial dimension. TOSA does not allow that,
Expand Down
24 changes: 22 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,26 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
}
};

class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern<ONNXSqrtOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resultTensorType = op.getResult().getType().cast<TensorType>();
if (failed(IsFloat::checkType(
rewriter, resultTensorType.getElementType(), op))) {
return failure();
}

Value input = op.getX();
TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value sqrtOp = tosaBuilder.sqrt(input);
rewriter.replaceOp(op, {sqrtOp});
return success();
}
};

class ONNXHardSigmoidOpLoweringToTOSA
: public OpConversionPattern<ONNXHardSigmoidOp> {
public:
Expand Down Expand Up @@ -464,8 +484,8 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA, ONNXLeakyReluOpLoweringToTOSA,
ONNXMulOpLoweringToTosa, ONNXClipOpLoweringToTOSA,
ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA>(
typeConverter, ctx);
ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA,
ONNXSqrtOpLoweringToTOSA>(typeConverter, ctx);

populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
patterns, typeConverter, ctx);
Expand Down
13 changes: 13 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,19 @@ func.func @test_pow_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)

// -----

func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func @test_sqrt
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.pow"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return [[VAR_1_]] : tensor<3xf32>
// CHECK-NEXT: }
}

// -----

func.func @test_abs_i32(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
Expand Down

0 comments on commit a320dec

Please sign in to comment.