Skip to content

Commit

Permalink
dialects: (linalg) add hidden region to BroadcastOp (#3840)
Browse files Browse the repository at this point in the history
Resolves pooling in #2959 ,
will be tested in #3837
  • Loading branch information
jorendumoulin authored and emmau678 committed Feb 6, 2025
1 parent 9eceb8c commit a216535
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16xf32>, tensor<32x64x16xf32>)
// CHECK: Operation does not verify: Input rank plus added dimensions (2) does not match output rank (3)
%res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array<i64: 1>} : (tensor<16xf32>, tensor<32x64x16xf32>) -> tensor<32x64x16xf32>
%res_broadcast = linalg.broadcast ins(%0 : tensor<16xf32>) outs(%1 : tensor<32x64x16xf32>) dimensions = [1]

}

Expand All @@ -53,7 +53,7 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>)
// CHECK: Operation does not verify: Dimension 0 is out of range. Expected range: [0, 1], got: 9
%res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array<i64: 9>} : (tensor<16xf32>, tensor<16x64xf32>) -> tensor<16x64xf32>
%res_broadcast = linalg.broadcast ins(%0 : tensor<16xf32>) outs(%1 : tensor<16x64xf32>) dimensions = [9]

}

Expand All @@ -62,6 +62,6 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<3x4x5xf32>, tensor<4x5x6x2xf32>)
// CHECK: Operation does not verify: input dimension 0 should match output dimension 0. input: 3, output: 4
%res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array<i64: 1>} : (tensor<3x4x5xf32>, tensor<4x5x6x2xf32>) -> tensor<4x5x6x2xf32>
%res_broadcast = linalg.broadcast ins(%0 : tensor<3x4x5xf32>) outs(%1 : tensor<4x5x6x2xf32>) dimensions = [1]

}
9 changes: 9 additions & 0 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,8 @@ class BroadcastOp(IRDLOperation):
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

hidden_region = region_def("single_block")

dimensions = attr_def(DenseArrayBase)

def __init__(
Expand All @@ -1078,12 +1080,19 @@ def __init__(
dimensions: Attribute,
result: Attribute | None = None,
):
arg_types = NamedOpBase.body_arg_types((input, init))

@Builder.implicit_region(arg_types)
def hidden_region(args: tuple[BlockArgument, ...]) -> None:
YieldOp(args[0])

super().__init__(
attributes={
"dimensions": dimensions,
},
operands=(input, init),
result_types=(result,),
regions=(hidden_region,),
)

def verify_(self) -> None:
Expand Down

0 comments on commit a216535

Please sign in to comment.