Skip to content

Commit

Permalink
Merge pull request #23 from Xilinx/tiagot.onnx_to_tosa_hardsigmoid
Browse files Browse the repository at this point in the history
ONNXToTosa: legalize ONNXHardSigmoid to tosa.
  • Loading branch information
mgehre-amd authored Jan 3, 2024
2 parents 4a5651c + aa45992 commit 0c6a798
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,17 @@ Value TosaBuilder::getConst(ArrayRef<float> vec, ArrayRef<int64_t> shape) {
return constOp;
}

// Creates a rank-preserving "splatted" constant: a tosa.const whose type has
// the same rank as `shape` but size 1 in every dimension, with all elements
// equal to `val`. The constant is always materialized as f32.
//
// If `dtype` is provided, the constant is additionally wrapped in a
// tosa.cast to a tensor of `shape` with element type `dtype`. createOrFold
// lets the cast fold away for splat inputs (the fold produces a splat
// constant of the target type directly) -- NOTE(review): if the cast does
// not fold, the operand/result shapes differ (all-ones vs `shape`); confirm
// downstream ops rely on the fold or on implicit broadcast.
//
// Note: the pre-diff single-argument overload signature was still present in
// the pasted hunk; it is removed here since the new optional `dtype`
// parameter defaults to empty and is backward compatible.
Value TosaBuilder::getSplattedConst(
    float val, llvm::ArrayRef<int64_t> shape, std::optional<Type> dtype) {
  auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
  auto constAttr = DenseElementsAttr::get(constType, val);

  auto constOp =
      rewriter().create<mlir::tosa::ConstOp>(loc(), constType, constAttr);

  if (dtype)
    return rewriter().createOrFold<mlir::tosa::CastOp>(
        loc(), RankedTensorType::get(shape, *dtype), constOp);
  return constOp;
}

Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ struct TosaBuilder : DialectBuilder {
llvm::ArrayRef<int8_t> vec, llvm::ArrayRef<int64_t> shape);
mlir::Value getConst(
llvm::ArrayRef<float> vec, llvm::ArrayRef<int64_t> shape);
// Create a floating-point constant operator from a float
// The tensor will have the same rank as shape but all dimensions will
// have size 1 (differs from tensorflow impl.)
mlir::Value getSplattedConst(float val, llvm::ArrayRef<int64_t> shape = {});
// If dtype is provided, the value is also cast to that dtype.
mlir::Value getSplattedConst(float val, llvm::ArrayRef<int64_t> shape = {},
std::optional<mlir::Type> dtype = {});

// Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
// to `value`.
Expand Down
48 changes: 47 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,51 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
}
};

// Lowers onnx.HardSigmoid, y = max(0, min(1, alpha * x + beta)), to TOSA.
//
// The expression is rewritten as alpha * clamp(x + beta/alpha, 0, 1/alpha),
// which is equivalent but needs only a single multiply after the clamp.
class ONNXHardSigmoidOpLoweringToTOSA
    : public OpConversionPattern<ONNXHardSigmoidOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXHardSigmoidOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // ONNXHardSigmoid -> TOSA:
    // - tosa.add(input, beta/alpha)
    // - tosa.clamp(add) with min = 0, and max = 1/alpha
    // - tosa.mul(clamp, alpha)
    Value input = adaptor.getX();

    auto resultType = op.getResult().getType().template cast<TensorType>();
    auto resultElementType = resultType.getElementType();

    TosaBuilder tosaBuilder(rewriter, op->getLoc());

    // alpha/beta are the op's attributes, exposed as APFloat.
    auto alpha = adaptor.getAlpha();

    // APFloat::divide mutates in place, so operate on the by-value copy of
    // beta rather than the attribute itself.
    auto betaOverAlpha = adaptor.getBeta();
    betaOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);

    // 1/alpha computed in alpha's own float semantics (e.g. f32 for an f32
    // attribute), so rounding matches the attribute's precision.
    APFloat oneOverAlpha(alpha.getSemantics(), 1);
    oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);

    // The splatted constants are cast to the result element type (see
    // TosaBuilder::getSplattedConst) so add/mul see matching operand types.
    Value constBetaOverAlpha =
        tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
            resultType.getShape(), resultElementType);
    Value constAlpha = tosaBuilder.getSplattedConst(
        alpha.convertToDouble(), resultType.getShape(), resultElementType);

    auto addOp =
        tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
    // Clamp to [0, 1/alpha]. NOTE(review): max_int is produced by an
    // implicit double->int64 conversion, which truncates when 1/alpha is not
    // integral -- confirm this is acceptable for integer element types.
    // min_fp/max_fp are f32 attrs per the tosa.clamp definition, even for
    // f16 results (see the FileCheck expectations in Elementwise.mlir).
    Value clampOp = tosa::CreateOpAndInfer<mlir::tosa::ClampOp>(rewriter,
        op->getLoc(), resultType, addOp, rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(oneOverAlpha.convertToDouble()),
        rewriter.getF32FloatAttr(0),
        rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
    auto mulOp = tosaBuilder.mul(clampOp, constAlpha);

    rewriter.replaceOp(op, {mulOp});
    return success();
  }
};

static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
Expand Down Expand Up @@ -372,7 +417,8 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA, ONNXLeakyReluOpLoweringToTOSA,
ONNXMulOpLoweringToTosa, ONNXClipOpLoweringToTOSA,
ONNXDivOpLoweringToTOSA>(typeConverter, ctx);
ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA>(
typeConverter, ctx);

populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
patterns, typeConverter, ctx);
Expand Down
56 changes: 56 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,59 @@ func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {
// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16>
// CHECK-NEXT: }
}

// -----

// Default values: alpha = 0.2, beta = 0.5
// f32, default attrs (alpha = 0.2, beta = 0.5): expect add of 2.5
// (= beta/alpha), clamp to [0, 5] (5 = 1/alpha), then mul by alpha.
func.func @test_hardsigmoid_default_values_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func.func @test_hardsigmoid_default_values_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAR_4_]] : tensor<3xf32>
}

// f16, default attrs: the alpha constant folds through tosa.cast to the
// nearest f16 (1.999510e-01); clamp min_fp/max_fp attrs stay f32.
func.func @test_hardsigmoid_default_values_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
%0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf16>) -> tensor<3xf16>
return %0 : tensor<3xf16>
// CHECK-LABEL: func @test_hardsigmoid_default_values_f16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.999510e-01> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}

// alpha = 0.166666672, beta = 5.000000e-01
// f32 with alpha = 1/6, beta = 0.5: beta/alpha = 3, 1/alpha rounds to 6
// (so max_fp = 6.0 and max_int = 6).
func.func @test_hardsigmoid_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func @test_hardsigmoid_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAR_4_]] : tensor<3xf32>
}

// f16 with alpha = 1/6, beta = 0.5: alpha constant folds to the nearest
// f16 (1.666260e-01); clamp bounds are the same as the f32 case.
func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
%0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf16>) -> tensor<3xf16>
return %0 : tensor<3xf16>
// CHECK-LABEL: func @test_hardsigmoid_f16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}

0 comments on commit 0c6a798

Please sign in to comment.