Skip to content

Commit

Permalink
dialects: (builtin) remove AnyMemRefType (#3831)
Browse files Browse the repository at this point in the history
`MemRefType` now has a default generic, so `AnyMemRefType` is
unnecessary.
  • Loading branch information
alexarice authored Feb 4, 2025
1 parent 5025a61 commit e6f5649
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 26 deletions.
1 change: 0 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,7 +1975,6 @@ def constr(
)


AnyMemRefType: TypeAlias = MemRefType[Attribute]
AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)


Expand Down
7 changes: 3 additions & 4 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
AnyFloatAttrConstr,
AnyIntegerAttr,
AnyIntegerAttrConstr,
AnyMemRefType,
ArrayAttr,
BoolAttr,
ContainerType,
Expand Down Expand Up @@ -1344,9 +1343,9 @@ def typcheck(
sig_typ: Attribute | type[Attribute],
) -> bool:
if isinstance(sig_typ, type):
return (
sig_typ == DsdType and isa(op_typ, AnyMemRefType)
) or isinstance(op_typ, sig_typ)
return (sig_typ == DsdType and isa(op_typ, MemRefType)) or isinstance(
op_typ, sig_typ
)
else:
return op_typ == sig_typ

Expand Down
3 changes: 1 addition & 2 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
Float16Type,
Expand Down Expand Up @@ -554,7 +553,7 @@ def parse(cls, parser: Parser):
],
properties=props,
)
elif isattr(res_type, base(AnyMemRefType)):
elif isattr(res_type, base(MemRefType)):
return cls.build(
operands=[temp],
result_types=[
Expand Down
9 changes: 4 additions & 5 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloat,
AnyMemRefType,
AnyTensorType,
ArrayAttr,
DenseArrayBase,
Expand Down Expand Up @@ -729,8 +728,8 @@ class TransposeOp(IRDLOperation):

name = "linalg.transpose"

input = operand_def(base(AnyMemRefType) | base(AnyTensorType))
init = operand_def(base(AnyMemRefType) | base(AnyTensorType))
input = operand_def(base(MemRefType) | base(AnyTensorType))
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

permutation = attr_def(DenseArrayBase)
Expand Down Expand Up @@ -1062,8 +1061,8 @@ class BroadcastOp(IRDLOperation):

name = "linalg.broadcast"

input = operand_def(base(AnyMemRefType) | base(AnyTensorType))
init = operand_def(base(AnyMemRefType) | base(AnyTensorType))
input = operand_def(base(MemRefType) | base(AnyTensorType))
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)

dimensions = attr_def(DenseArrayBase)
Expand Down
8 changes: 4 additions & 4 deletions xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, bufferization, func, linalg, memref, stencil, tensor
from xdsl.dialects.builtin import (
AnyMemRefType,
AnyTensorType,
AnyTensorTypeConstr,
DenseArrayBase,
DenseIntOrFPElementsAttr,
FunctionType,
MemRefType,
ModuleOp,
TensorType,
i64,
Expand Down Expand Up @@ -57,7 +57,7 @@ def to_tensor_op(
op: SSAValue, writable: bool = False, restrict: bool = True
) -> bufferization.ToTensorOp:
"""Creates a `bufferization.to_tensor` operation."""
assert isa(op.type, AnyMemRefType)
assert isa(op.type, MemRefType)
return bufferization.ToTensorOp(op, restrict, writable)


Expand Down Expand Up @@ -418,8 +418,8 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
linalg_op := yld_arg.op.tensor.op,
linalg.NamedOpBase | linalg.GenericOp,
)
or not isa(arg_t := arg.type, AnyMemRefType)
or not isa(yld_arg.type, AnyMemRefType)
or not isa(arg_t := arg.type, MemRefType)
or not isa(yld_arg.type, MemRefType)
):
new_dest.append(arg)
new_yield_args.append(yld_arg)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, llvm, memref, stencil
from xdsl.dialects.builtin import (
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
ArrayAttr,
DictionaryAttr,
IndexType,
IntegerAttr,
IntegerType,
MemRefType,
ShapedType,
Signedness,
StringAttr,
Expand Down Expand Up @@ -396,7 +396,7 @@ def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
or not (isinstance(start_call := end_call.arguments[0].owner, func.CallOp))
or not start_call.callee.string_value() == TIMER_START
or not (wrapper := _get_module_wrapper(op))
or not isa(op.ptr.type, AnyMemRefType)
or not isa(op.ptr.type, MemRefType)
):
return

Expand Down
6 changes: 3 additions & 3 deletions xdsl/transforms/linalg_to_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyIntegerAttr,
AnyMemRefType,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
MemRefType,
ModuleOp,
)
from xdsl.dialects.csl import csl
Expand Down Expand Up @@ -62,7 +62,7 @@ def transform_op(
f16: type[csl.BuiltinDsdOp],
f32: type[csl.BuiltinDsdOp],
):
if not isa(target_t := op.outputs.types[0], AnyMemRefType):
if not isa(target_t := op.outputs.types[0], MemRefType):
return

builtin = match_op_for_precision(target_t.get_element_type(), f16, f32)
Expand Down Expand Up @@ -91,7 +91,7 @@ class ConvertLinalgGenericFMAPass(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.GenericOp, rewriter: PatternRewriter, /):
if not self.is_fma(op) or not isa(op.outputs.types[0], AnyMemRefType):
if not self.is_fma(op) or not isa(op.outputs.types[0], MemRefType):
return

# one of the factors must be a scalar const, which the csl function signatures require
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyMemRefType,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
FloatAttr,
FunctionType,
IndexType,
IntegerAttr,
MemRefType,
ModuleOp,
UnrealizedConversionCastOp,
i16,
Expand Down Expand Up @@ -445,7 +445,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
return

if (
not isattr(accumulator.type, AnyMemRefType)
not isattr(accumulator.type, MemRefType)
or not isinstance(op.accumulator, OpResult)
or not isinstance(alloc := op.accumulator.op, memref.AllocOp)
):
Expand Down
5 changes: 2 additions & 3 deletions xdsl/transforms/memref_to_dsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from xdsl.dialects import arith, builtin, csl, memref
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyMemRefType,
ArrayAttr,
Float16Type,
Float32Type,
Expand Down Expand Up @@ -124,8 +123,8 @@ class LowerSubviewOpPass(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
assert isa(op.source.type, AnyMemRefType)
assert isa(op.result.type, AnyMemRefType)
assert isa(op.source.type, MemRefType)
assert isa(op.result.type, MemRefType)

if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
# 1d subview onto a nd memref
Expand Down

0 comments on commit e6f5649

Please sign in to comment.