Skip to content

Commit

Permalink
Merge pull request #22 from Xilinx/matthias.onnx_to_tosa_mul
Browse files Browse the repository at this point in the history
Fix lowering of ONNX.Mul to tosa.mul when rank-broadcasting
  • Loading branch information
mgehre-amd authored Dec 21, 2023
2 parents 01ea7ca + 1f82b1e commit db951e0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {
Value lhs = adaptor.getA();
Value rhs = adaptor.getB();

rewriter.replaceOpWithNewOp<mlir::tosa::MulOp>(
op, op.getType(), lhs, rhs, /*shift =*/0);
TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value mulOp = tosaBuilder.mul(lhs, rhs);
rewriter.replaceOp(op, {mulOp});

return success();
}
Expand Down
22 changes: 22 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> t

// -----

// Verifies that lowering onnx.Mul with operands of different ranks
// (3-D lhs, 2-D rhs) inserts a tosa.reshape that prepends a unit
// dimension to the lower-rank operand before emitting tosa.mul,
// since tosa.mul requires equal-rank operands.
func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_mul_rank_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array<i64: 1, 21, 1>}> : (tensor<21x1xf32>) -> tensor<1x21x1xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.mul"([[PARAM_0_]], [[VAR_0_]]) <{shift = 0 : i32}> : (tensor<13x21x1xf32>, tensor<1x21x1xf32>) -> tensor<13x21x1xf32>
}

// -----

// Same as test_mul_rank_broadcast, but with the lower-rank operand on
// the lhs: checks that the reshape is applied to the first operand
// (%arg0, 2-D) so both tosa.mul operands end up rank-3.
func.func @test_mul_rank_broadcast2(%arg0: tensor<21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_mul_rank_broadcast2
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_0_]]) <{new_shape = array<i64: 1, 21, 1>}> : (tensor<21x1xf32>) -> tensor<1x21x1xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.mul"([[VAR_0_]], [[PARAM_1_]]) <{shift = 0 : i32}> : (tensor<1x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
%0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
"func.return"(%0) : (tensor<13x21x1xi32>) -> ()
Expand Down

0 comments on commit db951e0

Please sign in to comment.