diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 8fb1227ee80ff50..946094e2e9f6911 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -10,6 +10,7 @@ # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * from .._linalg_enum_gen import * +from .._linalg_enum_gen import _iteratortypeenum # These are the ground truth functions defined as: # ``` @@ -58,6 +59,7 @@ from ...ir import * from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from ...extras.meta import region_op def transpose( @@ -102,3 +104,45 @@ def broadcast( ) fill_builtin_region(op.operation) return op + + +@register_attribute_builder("IteratorTypeArrayAttr") +def _IteratorTypeArrayAttr(x, context): + return ArrayAttr.get([_iteratortypeenum(v, context) for v in x]) + + +class GenericOp(GenericOp): + def __init__( + self, + inputs, + outputs, + indexing_maps, + iterator_types, + *, + doc=None, + library_call=None, + loc=None, + ip=None, + ): + result_types = [] + if isinstance(outputs[0].type, RankedTensorType): + result_types = [o.type for o in outputs] + + super().__init__( + result_types, + inputs, + outputs, + indexing_maps, + iterator_types, + doc=doc, + library_call=library_call, + loc=loc, + ip=ip, + ) + element_types = [i.type.element_type for i in inputs] + [ + o.type.element_type for o in outputs + ] + self.regions[0].blocks.append(*element_types) + + +generic = region_op(GenericOp, terminator=YieldOp) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 72045a07b2da800..b7e0f2884bb2492 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -84,6 +84,7 @@ def named_form(lhs, rhs): print(module) + # CHECK-LABEL: TEST: testIdentityRegionOps @run def testIdentityRegionOps(): @@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3): op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0]) print(module) + + +# CHECK-LABEL: TEST: testGenericOp +@run +def testGenericOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + id_map = AffineMap.get_identity(2) + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32> + # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32> + x = tensor.empty((16, 16), f32) + y = tensor.empty((16, 16), f32) + + # CHECK: %[[VAL_3:*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) { + # CHECK: ^bb0(%in: f32, %out: f32): + # CHECK: linalg.yield %in : f32 + # CHECK: } -> tensor<16x16xf32> + @linalg.generic( + [x], + [y], + [id_map, id_map], + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], + ) + def f(x, y): + return x + + assert isinstance(f, Value) + + # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32> + z = tensor.empty((16, 16, 16), f32) + + minor_id = AffineMap.get_minor_identity(3, 2) + id_map = AffineMap.get_identity(3) + + # CHECK: %%[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) { + # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32): + # CHECK: linalg.yield %in, %out : f32, f32 + # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>) + @linalg.generic( + [x], + [z, z], + [minor_id, id_map, id_map], + [ + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + ], + ) + def g(x, z1, z2): + return x, z1 + + assert isinstance(g, OpResultList) + assert len(g) == 2 + assert isinstance(g[0].type, RankedTensorType) + assert isinstance(g[1].type, RankedTensorType) + + print(module)