Merge pull request #27 from Xilinx/tiagot.onnx_to_tosa_pow
feat(ONNXTOTosa): legalization for pow.
ttjost authored Jan 4, 2024
2 parents 1a00015 + a320dec commit 4db5ac4
Showing 4 changed files with 97 additions and 19 deletions.
10 changes: 10 additions & 0 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -253,6 +253,16 @@ template Value TosaBuilder::binaryOp<mlir::tosa::AddOp>(
template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(
mlir::Value &lhs, mlir::Value &rhs);

template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(
mlir::Value &lhs, mlir::Value &rhs);

Value TosaBuilder::sqrt(mlir::Value &input) {
auto inputType = input.getType().cast<ShapedType>();
auto oneHalf = this->getSplattedConst(
0.5, inputType.getShape(), inputType.getElementType());
return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
}

static bool containsNonZero(llvm::SmallVectorImpl<int64_t> &values) {
return llvm::any_of(values, [](int64_t value) { return value != 0; });
}
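Note: the new TosaBuilder::sqrt helper relies on the identity sqrt(x) = x^0.5: it splats a 0.5 constant matching the input's shape (via getSplattedConst) and emits a tosa.pow. A minimal usage sketch, mirroring the ONNXSqrtOpLoweringToTOSA pattern added below (the surrounding matchAndRewrite boilerplate is assumed):

// Inside a conversion pattern's matchAndRewrite (sketch only).
TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value input = op.getX();                // float tensor operand
Value result = tosaBuilder.sqrt(input); // emits tosa.const(0.5) + tosa.pow
rewriter.replaceOp(op, {result});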
1 change: 1 addition & 0 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -50,6 +50,7 @@ struct TosaBuilder : DialectBuilder {
llvm::ArrayRef<int64_t> start);
mlir::Value reshape(mlir::Value &value, llvm::ArrayRef<int64_t> shape);
mlir::Value reciprocal(mlir::Value &input);
mlir::Value sqrt(mlir::Value &input);

/// When using window based ops like maxpool or conv2d, we sometimes have
/// unused values at the end of a spatial dimension. TOSA does not allow that,
73 changes: 54 additions & 19 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -79,6 +79,17 @@ struct IsInt {
}
};

struct IsFloat {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
if (!isTOSAFloat(scalarType)) {
return rewriter.notifyMatchFailure(
op, "this operation only supports float types");
}
return success();
}
};

struct IsBool {
static LogicalResult checkType(
ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) {
@@ -90,14 +101,14 @@ struct IsBool {
}
};

-template <typename OpAdaptorT>
+template <typename OpAdaptorT, typename TypeChecker>
LogicalResult checkBasicTosaRequirementsForBinaryOps(
ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor,
Type resultType) {
-Value lhs = adaptor.getA();
+Value lhs = adaptor.getOperands()[0];
auto lhsType = lhs.getType().dyn_cast<TensorType>();

-Value rhs = adaptor.getB();
+Value rhs = adaptor.getOperands()[1];
auto rhsType = rhs.getType().dyn_cast<TensorType>();

auto resultTensorType = resultType.dyn_cast<TensorType>();
@@ -107,8 +118,8 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(

Type resultElementType = resultTensorType.getElementType();

-if (!resultElementType.isIntOrFloat()) {
-return rewriter.notifyMatchFailure(op, "only int and float are supported");
+if (failed(TypeChecker::checkType(rewriter, resultElementType, op))) {
+return failure();
}

return success();
@@ -150,7 +161,7 @@ class ONNXElementwiseUnaryOpLoweringToTOSA
}
};

-template <typename ONNXOpT, typename TosaOpT>
+template <typename ONNXOpT, typename TosaOpT, typename TypeChecker>
class ONNXBinaryElementwiseOpLoweringToTOSA
: public OpConversionPattern<ONNXOpT> {
public:
@@ -159,13 +170,13 @@ class ONNXBinaryElementwiseOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

-if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor>(
+if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, TypeChecker>(
rewriter, op, adaptor, op.getResult().getType())))
return failure();

auto loc = op.getLoc();
-Value lhs = adaptor.getA();
-Value rhs = adaptor.getB();
+Value lhs = adaptor.getOperands()[0];
+Value rhs = adaptor.getOperands()[1];

if (TosaOpT::template hasTrait<
mlir::OpTrait::ResultsBroadcastableShape>()) {
@@ -194,7 +205,7 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXMulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor>(
+if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, IsIntOrFloat>(
rewriter, op, adaptor, op.getResult().getType())))
return failure();

@@ -355,6 +366,26 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
}
};

class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern<ONNXSqrtOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resultTensorType = op.getResult().getType().cast<TensorType>();
if (failed(IsFloat::checkType(
rewriter, resultTensorType.getElementType(), op))) {
return failure();
}

Value input = op.getX();
TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value sqrtOp = tosaBuilder.sqrt(input);
rewriter.replaceOp(op, {sqrtOp});
return success();
}
};

class ONNXHardSigmoidOpLoweringToTOSA
: public OpConversionPattern<ONNXHardSigmoidOp> {
public:
@@ -404,15 +435,19 @@ static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAndOp,
-mlir::tosa::LogicalAndOp>,
+mlir::tosa::LogicalAndOp, IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseAndOp,
-mlir::tosa::BitwiseAndOp>,
-ONNXBinaryElementwiseOpLoweringToTOSA<ONNXOrOp, mlir::tosa::LogicalOrOp>,
+mlir::tosa::BitwiseAndOp, IsInt>,
+ONNXBinaryElementwiseOpLoweringToTOSA<ONNXOrOp, mlir::tosa::LogicalOrOp,
+IsBool>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseOrOp,
-mlir::tosa::BitwiseOrOp>,
-ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp, mlir::tosa::AddOp>,
-ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp>>(
-typeConverter, ctx);
+mlir::tosa::BitwiseOrOp, IsInt>,
+ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp, mlir::tosa::AddOp,
+IsIntOrFloat>,
+ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp,
+IsIntOrFloat>,
+ONNXBinaryElementwiseOpLoweringToTOSA<ONNXPowOp, mlir::tosa::PowOp,
+IsFloat>>(typeConverter, ctx);
}

static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
Expand Down Expand Up @@ -449,8 +484,8 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA, ONNXLeakyReluOpLoweringToTOSA,
ONNXMulOpLoweringToTosa, ONNXClipOpLoweringToTOSA,
-ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA>(
-typeConverter, ctx);
+ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA,
+ONNXSqrtOpLoweringToTOSA>(typeConverter, ctx);

populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
patterns, typeConverter, ctx);
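Note: the TypeChecker template parameter is a small policy struct (IsBool, IsInt, IsFloat, or IsIntOrFloat) whose static checkType gates each pattern on the result element type, so each TOSA op only matches element types it legally supports. Registering a further binary op would follow the same shape; a hypothetical example (ONNXMaxOp to mlir::tosa::MaximumOp is illustrative only and not part of this commit):

// Hypothetical registration following the scheme above (sketch only).
patterns.insert<ONNXBinaryElementwiseOpLoweringToTOSA<ONNXMaxOp,
    mlir::tosa::MaximumOp, IsIntOrFloat>>(typeConverter, ctx);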
32 changes: 32 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -322,6 +322,38 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens

// -----

func.func @test_pow(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_pow
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.pow"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
}

func.func @test_pow_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_pow_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1xf32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.pow"([[PARAM_0_]], [[VAR_0_]]) : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func @test_sqrt
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.pow"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return [[VAR_1_]] : tensor<3xf32>
// CHECK-NEXT: }
}

// -----

func.func @test_abs_i32(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
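Note: the new lit cases follow this file's existing CHECK conventions and run under the RUN line already at the top of the file, which the diff does not show. It is presumably of the usual onnx-mlir form (an assumption; the file's actual RUN line is authoritative):

// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s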
