From e6f5649e0d5ef127422ecaed6990cbc3a4229b5b Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 4 Feb 2025 15:22:31 +0000 Subject: [PATCH] dialects: (builtin) remove AnyMemRefType (#3831) `MemRefType` now has a default generic, so `AnyMemRefType` is unnecessary. --- xdsl/dialects/builtin.py | 1 - xdsl/dialects/csl/csl.py | 7 +++---- xdsl/dialects/csl/csl_stencil.py | 3 +-- xdsl/dialects/linalg.py | 9 ++++----- xdsl/transforms/csl_stencil_bufferize.py | 8 ++++---- xdsl/transforms/csl_stencil_to_csl_wrapper.py | 4 ++-- xdsl/transforms/linalg_to_csl.py | 6 +++--- xdsl/transforms/lower_csl_stencil.py | 4 ++-- xdsl/transforms/memref_to_dsd.py | 5 ++--- 9 files changed, 21 insertions(+), 26 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index d818adb499..49ad1bec6f 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1975,7 +1975,6 @@ def constr( ) -AnyMemRefType: TypeAlias = MemRefType[Attribute] AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType) diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index 0a35ee78f7..606e72174d 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -21,7 +21,6 @@ AnyFloatAttrConstr, AnyIntegerAttr, AnyIntegerAttrConstr, - AnyMemRefType, ArrayAttr, BoolAttr, ContainerType, @@ -1344,9 +1343,9 @@ def typcheck( sig_typ: Attribute | type[Attribute], ) -> bool: if isinstance(sig_typ, type): - return ( - sig_typ == DsdType and isa(op_typ, AnyMemRefType) - ) or isinstance(op_typ, sig_typ) + return (sig_typ == DsdType and isa(op_typ, MemRefType)) or isinstance( + op_typ, sig_typ + ) else: return op_typ == sig_typ diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index a5ace78bf6..24b2340fd9 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -6,7 +6,6 @@ from xdsl.dialects.builtin import ( AnyFloat, AnyIntegerAttr, - AnyMemRefType, AnyMemRefTypeConstr, AnyTensorTypeConstr, Float16Type, @@ -554,7 +553,7 @@ def parse(cls, parser: Parser): ], properties=props, ) - elif isattr(res_type, base(AnyMemRefType)): + elif isattr(res_type, base(MemRefType)): return cls.build( operands=[temp], result_types=[ diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 065280c3be..9d3d4a6d87 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -12,7 +12,6 @@ from xdsl.dialects.builtin import ( AffineMapAttr, AnyFloat, - AnyMemRefType, AnyTensorType, ArrayAttr, DenseArrayBase, @@ -729,8 +728,8 @@ class TransposeOp(IRDLOperation): name = "linalg.transpose" - input = operand_def(base(AnyMemRefType) | base(AnyTensorType)) - init = operand_def(base(AnyMemRefType) | base(AnyTensorType)) + input = operand_def(base(MemRefType) | base(AnyTensorType)) + init = operand_def(base(MemRefType) | base(AnyTensorType)) result = var_result_def(AnyTensorType) permutation = attr_def(DenseArrayBase) @@ -1062,8 +1061,8 @@ class BroadcastOp(IRDLOperation): name = "linalg.broadcast" - input = operand_def(base(AnyMemRefType) | base(AnyTensorType)) - init = operand_def(base(AnyMemRefType) | base(AnyTensorType)) + input = operand_def(base(MemRefType) | base(AnyTensorType)) + init = operand_def(base(MemRefType) | base(AnyTensorType)) result = var_result_def(AnyTensorType) dimensions = attr_def(DenseArrayBase) diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index add9764e36..01cb99eb54 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -4,12 +4,12 @@ from xdsl.context import MLContext from xdsl.dialects import arith, bufferization, func, linalg, memref, stencil, tensor from xdsl.dialects.builtin import ( - AnyMemRefType, AnyTensorType, AnyTensorTypeConstr, DenseArrayBase, DenseIntOrFPElementsAttr, FunctionType, + MemRefType, ModuleOp, TensorType, i64, @@ -57,7 +57,7 @@ def to_tensor_op( op: SSAValue, writable: bool = False, restrict: bool = True ) -> bufferization.ToTensorOp: """Creates a `bufferization.to_tensor` operation.""" - assert isa(op.type, AnyMemRefType) + assert isa(op.type, MemRefType) return bufferization.ToTensorOp(op, restrict, writable) @@ -418,8 +418,8 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, linalg_op := yld_arg.op.tensor.op, linalg.NamedOpBase | linalg.GenericOp, ) - or not isa(arg_t := arg.type, AnyMemRefType) - or not isa(yld_arg.type, AnyMemRefType) + or not isa(arg_t := arg.type, MemRefType) + or not isa(yld_arg.type, MemRefType) ): new_dest.append(arg) new_yield_args.append(yld_arg) diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py index 3b376d797a..501cf9fdd3 100644 --- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py +++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py @@ -5,7 +5,6 @@ from xdsl.context import MLContext from xdsl.dialects import arith, builtin, func, llvm, memref, stencil from xdsl.dialects.builtin import ( - AnyMemRefType, AnyMemRefTypeConstr, AnyTensorTypeConstr, ArrayAttr, @@ -13,6 +12,7 @@ IndexType, IntegerAttr, IntegerType, + MemRefType, ShapedType, Signedness, StringAttr, @@ -396,7 +396,7 @@ def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /): or not (isinstance(start_call := end_call.arguments[0].owner, func.CallOp)) or not start_call.callee.string_value() == TIMER_START or not (wrapper := _get_module_wrapper(op)) - or not isa(op.ptr.type, AnyMemRefType) + or not isa(op.ptr.type, MemRefType) ): return diff --git a/xdsl/transforms/linalg_to_csl.py b/xdsl/transforms/linalg_to_csl.py index 683dce326e..f50c50617d 100644 --- a/xdsl/transforms/linalg_to_csl.py +++ b/xdsl/transforms/linalg_to_csl.py @@ -5,10 +5,10 @@ from xdsl.dialects.builtin import ( AnyFloatAttr, AnyIntegerAttr, - AnyMemRefType, DenseIntOrFPElementsAttr, Float16Type, Float32Type, + MemRefType, ModuleOp, ) from xdsl.dialects.csl import csl @@ -62,7 +62,7 @@ def transform_op( f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp], ): - if not isa(target_t := op.outputs.types[0], AnyMemRefType): + if not isa(target_t := op.outputs.types[0], MemRefType): return builtin = match_op_for_precision(target_t.get_element_type(), f16, f32) @@ -91,7 +91,7 @@ class ConvertLinalgGenericFMAPass(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: linalg.GenericOp, rewriter: PatternRewriter, /): - if not self.is_fma(op) or not isa(op.outputs.types[0], AnyMemRefType): + if not self.is_fma(op) or not isa(op.outputs.types[0], MemRefType): return # one of the factors must be a scalar const, which the csl function signatures require diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 3f24f156a9..d2b4b5afd7 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -6,7 +6,6 @@ from xdsl.dialects.builtin import ( AffineMapAttr, AnyFloatAttr, - AnyMemRefType, DenseIntOrFPElementsAttr, Float16Type, Float32Type, @@ -14,6 +13,7 @@ FunctionType, IndexType, IntegerAttr, + MemRefType, ModuleOp, UnrealizedConversionCastOp, i16, @@ -445,7 +445,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, return if ( - not isattr(accumulator.type, AnyMemRefType) + not isattr(accumulator.type, MemRefType) or not isinstance(op.accumulator, OpResult) or not isinstance(alloc := op.accumulator.op, memref.AllocOp) ): diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 47720ea13e..a637f40033 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -7,7 +7,6 @@ from xdsl.dialects import arith, builtin, csl, memref from xdsl.dialects.builtin import ( AffineMapAttr, - AnyMemRefType, ArrayAttr, Float16Type, Float32Type, @@ -124,8 +123,8 @@ class LowerSubviewOpPass(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /): - assert isa(op.source.type, AnyMemRefType) - assert isa(op.result.type, AnyMemRefType) + assert isa(op.source.type, MemRefType) + assert isa(op.result.type, MemRefType) if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1: # 1d subview onto a nd memref