Skip to content

Commit

Permalink
[mlir][python] implement GenericOp bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jan 27, 2025
1 parent c1ec5be commit 206cb67
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
44 changes: 44 additions & 0 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
from .._linalg_enum_gen import _iteratortypeenum

# These are the ground truth functions defined as:
# ```
Expand Down Expand Up @@ -58,6 +59,7 @@

from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
from ...extras.meta import region_op


def transpose(
Expand Down Expand Up @@ -102,3 +104,45 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op


@register_attribute_builder("IteratorTypeArrayAttr")
def _IteratorTypeArrayAttr(x, context):
    """Build an ArrayAttr of iterator-type attributes.

    Each element of *x* is converted with the generated ``_iteratortypeenum``
    builder; the resulting attributes are collected into a single ArrayAttr
    so ODS can accept a plain Python sequence for ``IteratorTypeArrayAttr``.
    """
    attributes = []
    for value in x:
        attributes.append(_iteratortypeenum(value, context))
    return ArrayAttr.get(attributes)


class GenericOp(GenericOp):
    """Convenience wrapper over the generated ``linalg.generic`` builder.

    Infers the op's result types from the output operands and creates the
    region's entry block so callers only need to populate the body.
    """

    def __init__(
        self,
        inputs,
        outputs,
        indexing_maps,
        iterator_types,
        *,
        doc=None,
        library_call=None,
        loc=None,
        ip=None,
    ):
        """Create a linalg.generic op.

        Args:
          inputs: input operands (shaped values).
          outputs: output operands; when they are ranked tensors the op
            returns one tensor result per output, otherwise (e.g. memrefs)
            the op has no results.
          indexing_maps: one AffineMap per input/output operand.
          iterator_types: one IteratorType per loop dimension.
          doc: optional documentation attribute.
          library_call: optional library-call attribute.
          loc: optional Location.
          ip: optional InsertionPoint.
        """
        # Tensor semantics: one result per output operand. Guard against an
        # empty `outputs` list so we fail in the op verifier with a real
        # diagnostic instead of an IndexError here.
        result_types = []
        if outputs and isinstance(outputs[0].type, RankedTensorType):
            result_types = [o.type for o in outputs]

        super().__init__(
            result_types,
            inputs,
            outputs,
            indexing_maps,
            iterator_types,
            doc=doc,
            library_call=library_call,
            loc=loc,
            ip=ip,
        )
        # The region's entry block takes one scalar argument (the operand's
        # element type) per input and per output, in that order.
        element_types = [i.type.element_type for i in inputs] + [
            o.type.element_type for o in outputs
        ]
        self.regions[0].blocks.append(*element_types)


generic = region_op(GenericOp, terminator=YieldOp)
60 changes: 60 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def named_form(lhs, rhs):

print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
Expand Down Expand Up @@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

print(module)


# CHECK-LABEL: TEST: testGenericOp
@run
def testGenericOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            id_map = AffineMap.get_identity(2)
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # Fixed FileCheck capture: `[[VAL_3:*]]` is an invalid regex
            # (bare `*`); use `.*` like the other captures.
            # CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK: linalg.yield %in : f32
            # CHECK: } -> tensor<16x16xf32>
            @linalg.generic(
                [x],
                [y],
                [id_map, id_map],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(x, y):
                return x

            assert isinstance(f, Value)

            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
            z = tensor.empty((16, 16, 16), f32)

            minor_id = AffineMap.get_minor_identity(3, 2)
            id_map = AffineMap.get_identity(3)

            # Fixed FileCheck directive: a literal `%%` never appears in MLIR
            # output, so the original line could not match.
            # CHECK: %[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
            # CHECK: linalg.yield %in, %out : f32, f32
            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
            @linalg.generic(
                [x],
                [z, z],
                [minor_id, id_map, id_map],
                [
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                ],
            )
            def g(x, z1, z2):
                return x, z1

            assert isinstance(g, OpResultList)
            assert len(g) == 2
            assert isinstance(g[0].type, RankedTensorType)
            assert isinstance(g[1].type, RankedTensorType)

        print(module)

0 comments on commit 206cb67

Please sign in to comment.