Skip to content

Commit

Permalink
dialects: (linalg) add hidden region to transpose op (#3838)
Browse files Browse the repository at this point in the history
This adds a hidden region to the linalg.transpose op to ensure correct
generic printing
Also changes permutation to a property instead of attribute.

This resolves the transpose op in #2959 

This has now been checked manually, and will be put in ci with #3837
(but for that 3 other ops need to be fixed, PRs incoming...)
  • Loading branch information
jorendumoulin authored Feb 5, 2025
1 parent f7458f8 commit f0ff49f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16x1xf32>)

// CHECK: Operation does not verify: Input rank (2) does not match output rank (3)
%res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 0>} : (tensor<16x64xf32>, tensor<64x16x1xf32>) -> tensor<64x16x1xf32>
%res_transpose = linalg.transpose ins(%0 : tensor<16x64xf32>) outs(%1 : tensor<64x16x1xf32>) permutation = [1, 0]

}

Expand All @@ -25,7 +25,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16xf32>)

// CHECK: Operation does not verify: Input rank (2) does not match size of permutation (3)
%res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 2, 3>} : (tensor<16x64xf32>, tensor<64x16xf32>) -> tensor<64x16xf32>
%res_transpose = linalg.transpose ins(%0 : tensor<16x64xf32>) outs(%1 : tensor<64x16xf32>) permutation = [1, 2, 3]

}

Expand All @@ -35,7 +35,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x32x64xf32>, tensor<32x64x16xf32>)

// CHECK: Operation does not verify: dim(result, 1) = 64 doesn't match dim(input, permutation[1]) = 32
%res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 1, 2>} : (tensor<16x32x64xf32>, tensor<32x64x16xf32>) -> tensor<32x64x16xf32>
%res_transpose = linalg.transpose ins(%0 : tensor<16x32x64xf32>) outs(%1 : tensor<32x64x16xf32>) permutation = [1, 1, 2]

}

Expand Down
13 changes: 11 additions & 2 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,9 @@ class TransposeOp(IRDLOperation):
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

permutation = attr_def(DenseArrayBase)
hidden_region = region_def("single_block")

permutation = prop_def(DenseArrayBase)

def __init__(
self,
Expand All @@ -741,12 +743,19 @@ def __init__(
permutation: Attribute,
result: Attribute | None = None,
):
arg_types = NamedOpBase.body_arg_types((input, init))

@Builder.implicit_region(arg_types)
def hidden_region(args: tuple[BlockArgument, ...]) -> None:
YieldOp(args[0])

super().__init__(
attributes={
properties={
"permutation": permutation,
},
operands=(input, init),
result_types=(result,),
regions=(hidden_region,),
)

def verify_(self) -> None:
Expand Down

0 comments on commit f0ff49f

Please sign in to comment.