
Merge conflicts
Ferdinand Lemaire committed Jan 3, 2024
2 parents d29b8ba + 0c6a798 commit 0764b8f
Showing 6 changed files with 154 additions and 9 deletions.
7 changes: 6 additions & 1 deletion src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -129,12 +129,17 @@ Value TosaBuilder::getConst(ArrayRef<float> vec, ArrayRef<int64_t> shape) {
return constOp;
}

Value TosaBuilder::getSplattedConst(float val, llvm::ArrayRef<int64_t> shape) {
Value TosaBuilder::getSplattedConst(
float val, llvm::ArrayRef<int64_t> shape, std::optional<Type> dtype) {
auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
auto constAttr = DenseElementsAttr::get(constType, val);

auto constOp =
rewriter().create<mlir::tosa::ConstOp>(loc(), constType, constAttr);

if (dtype)
return rewriter().createOrFold<mlir::tosa::CastOp>(
loc(), RankedTensorType::get(shape, *dtype), constOp);
return constOp;
}
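// Illustrative usage sketch (not part of this change; resultType is a
// hypothetical caller-side value): without a dtype the splatted constant keeps
// f32 elements as before, while passing a dtype also folds in a cast to that
// element type, e.g.
//   Value c = tosaBuilder.getSplattedConst(
//       0.5f, resultType.getShape(), resultType.getElementType());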

6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -71,10 +71,12 @@ struct TosaBuilder : DialectBuilder {
llvm::ArrayRef<int8_t> vec, llvm::ArrayRef<int64_t> shape);
mlir::Value getConst(
llvm::ArrayRef<float> vec, llvm::ArrayRef<int64_t> shape);
// Create a 32-bit float constant operator from a float
// Create a floating-point constant operator from a float
// The tensor will have the same rank as shape but all dimensions will
// have size 1 (differs from tensorflow impl.)
mlir::Value getSplattedConst(float val, llvm::ArrayRef<int64_t> shape = {});
// If dtype is provided, it also casts the value to the appropriate dtype.
mlir::Value getSplattedConst(float val, llvm::ArrayRef<int64_t> shape = {},
std::optional<mlir::Type> dtype = {});

// Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
// to `value`.
61 changes: 55 additions & 6 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -334,14 +334,62 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
}
};

class ONNXHardSigmoidOpLoweringToTOSA
: public OpConversionPattern<ONNXHardSigmoidOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXHardSigmoidOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// ONNXHardSigmoid -> TOSA:
// - tosa.add(input, beta/alpha)
// - tosa.clamp(add) with min = 0, and max = 1/alpha
// - tosa.mul(clamp, alpha)
Value input = adaptor.getX();

auto resultType = op.getResult().getType().template cast<TensorType>();
auto resultElementType = resultType.getElementType();

TosaBuilder tosaBuilder(rewriter, op->getLoc());

auto alpha = adaptor.getAlpha();

auto betaOverAlpha = adaptor.getBeta();
betaOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);

APFloat oneOverAlpha(alpha.getSemantics(), 1);
oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);

Value constBetaOverAlpha =
tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
resultType.getShape(), resultElementType);
Value constAlpha = tosaBuilder.getSplattedConst(
alpha.convertToDouble(), resultType.getShape(), resultElementType);

auto addOp =
tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
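// Note: tosa.clamp carries both integer and floating-point bounds; for the
// floating-point tensors handled by this pattern, the min_fp/max_fp pair is
// the one that takes effect.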
Value clampOp = tosa::CreateOpAndInfer<mlir::tosa::ClampOp>(rewriter,
op->getLoc(), resultType, addOp, rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(oneOverAlpha.convertToDouble()),
rewriter.getF32FloatAttr(0),
rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
auto mulOp = tosaBuilder.mul(clampOp, constAlpha);

rewriter.replaceOp(op, {mulOp});
return success();
}
};
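// A minimal scalar sketch of the algebra behind the lowering above (illustrative
// only; hardSigmoidRef is a hypothetical helper and assumes <algorithm> for
// std::clamp). HardSigmoid(x) = clamp(alpha * x + beta, 0, 1)
//                             = alpha * clamp(x + beta / alpha, 0, 1 / alpha),
// i.e. exactly the add -> clamp -> mul sequence emitted by this pattern.
static float hardSigmoidRef(float x, float alpha, float beta) {
  return alpha * std::clamp(x + beta / alpha, 0.0f, 1.0f / alpha);
}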

static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAndOp, mlir::tosa::LogicalAndOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseAndOp, mlir::tosa::BitwiseAndOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXOrOp, mlir::tosa::LogicalOrOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseOrOp, mlir::tosa::BitwiseOrOp>,
patterns.insert<ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAndOp,
mlir::tosa::LogicalAndOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseAndOp,
mlir::tosa::BitwiseAndOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXOrOp,
mlir::tosa::LogicalOrOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXBitwiseOrOp,
mlir::tosa::BitwiseOrOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXAddOp, mlir::tosa::AddOp>,
ONNXBinaryElementwiseOpLoweringToTOSA<ONNXSubOp, mlir::tosa::SubOp>>(
typeConverter, ctx);
@@ -377,7 +425,8 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA, ONNXLeakyReluOpLoweringToTOSA,
ONNXMulOpLoweringToTosa, ONNXClipOpLoweringToTOSA,
ONNXDivOpLoweringToTOSA>(typeConverter, ctx);
ONNXDivOpLoweringToTOSA, ONNXHardSigmoidOpLoweringToTOSA>(
typeConverter, ctx);

populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern(
patterns, typeConverter, ctx);
22 changes: 22 additions & 0 deletions src/Transform/ONNX/Decompose.cpp
@@ -556,6 +556,26 @@ struct ConcatFusePattern : public ConversionPattern {
}
};

// ONNXHardSwishOp(input) can be decomposed as:
// input * ONNXHardSigmoid(input), with alpha = 1/6 and beta = 0.5.
struct DecomposeHardSwishPattern : public ConversionPattern {
DecomposeHardSwishPattern(MLIRContext *context)
: ConversionPattern(ONNXHardSwishOp::getOperationName(), 4, context) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

ONNXHardSwishOp hardSwishOp = ::llvm::dyn_cast<ONNXHardSwishOp>(op);

auto input = hardSwishOp.getX();
auto hardSigmoid = rewriter.create<ONNXHardSigmoidOp>(op->getLoc(),
hardSwishOp.getType(), input, rewriter.getF32FloatAttr(1.0 / 6.0),
rewriter.getF32FloatAttr(0.5));
rewriter.replaceOpWithNewOp<ONNXMulOp>(
op, hardSwishOp.getType(), input, hardSigmoid);
return success();
}
};
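// Scalar reference for the decomposition above (an illustrative sketch, not part
// of the pattern; hardSwishRef is hypothetical and assumes <algorithm> for
// std::clamp): HardSwish(x) = x * HardSigmoid(x; alpha = 1/6, beta = 0.5)
//                           = x * clamp(x / 6 + 0.5, 0, 1),
// so e.g. hardSwishRef(3.0f) == 3.0f and hardSwishRef(-3.0f) == 0.0f.
static float hardSwishRef(float x) {
  return x * std::clamp(x / 6.0f + 0.5f, 0.0f, 1.0f);
}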

// Decompose the custom op FusedMatMul that is produced by ONNXRuntime.
// According to FusedMatMul specification, it is the result of fusing MatMul and
// Transpose:
@@ -788,6 +808,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV7Op>();
target.addIllegalOp<ONNXUnsqueezeV11Op>();
target.addIllegalOp<ONNXHardSwishOp>();
target.addDynamicallyLegalOp<ONNXConcatOp>([](ONNXConcatOp op) {
ONNXShapeOp shapeOp = NULL;
ONNXTransposeOp transposeOp = NULL;
@@ -827,6 +848,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
populateWithGenerated(patterns);
patterns.insert<onnx_mlir::DecomposeEinsumPattern>(&getContext());
patterns.insert<ConcatFusePattern>(&getContext());
patterns.insert<DecomposeHardSwishPattern>(&getContext());
// Decompose CustomOp FusedMatMul introduced by onnxruntime:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
patterns.insert<CustomOpFuseMatMulPattern>(&getContext());
55 changes: 55 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -366,6 +366,61 @@ func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> {

// -----

// Default values: alpha = 0.2, beta = 0.5
func.func @test_hardsigmoid_default_values_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func.func @test_hardsigmoid_default_values_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAR_4_]] : tensor<3xf32>
}

func.func @test_hardsigmoid_default_values_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
%0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf16>) -> tensor<3xf16>
return %0 : tensor<3xf16>
// CHECK-LABEL: func @test_hardsigmoid_default_values_f16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.999510e-01> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}
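// Note: because the splatted alpha constant is now cast to the result element
// type, the f16 checks above see the f16-rounded value (1.999510e-01 is the
// default alpha of 0.2 rounded to the nearest f16).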

// alpha = 0.166666672, beta = 5.000000e-01
func.func @test_hardsigmoid_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-LABEL: func @test_hardsigmoid_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<3xf32>}> : () -> tensor<3xf32>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf32>) -> tensor<3xf32>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK: return [[VAR_4_]] : tensor<3xf32>
}

func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
%0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf16>) -> tensor<3xf16>
return %0 : tensor<3xf16>
// CHECK-LABEL: func @test_hardsigmoid_f16
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16>
// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: [[VAR_2_:%.+]] = "tosa.add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_3_:%.+]] = "tosa.clamp"([[VAR_2_]]) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<3xf16>) -> tensor<3xf16>
// CHECK: [[VAR_4_:%.+]] = "tosa.mul"([[VAR_3_]], [[VAR_1_]]) <{shift = 0 : i32}> : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
// CHECK: return [[VAR_4_]] : tensor<3xf16>
}
// -----

func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1>
"func.return"(%0) : (tensor<13x21x1xi1>) -> ()
12 changes: 12 additions & 0 deletions test/mlir/onnx/onnx_decompose.mlir
@@ -480,3 +480,15 @@ func.func @test_concatfuse_2(%arg0: tensor<?x20xf32>, %arg1: tensor<?x30xf32>) -
// CHECK: onnx.Return [[VAR_1_]], [[VAR_2_]] : tensor<2xi64>, tensor<?x50xf32>
// CHECK: }
}

// -----

func.func @test_hardswish_f32(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = "onnx.HardSwish"(%arg0) : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
// CHECK-LABEL: func @test_hardswish_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: return [[VAR_1_]] : tensor<?x?x?xf32>
}
