Skip to content

Commit

Permalink
dialects: (builtin) remove AnyIntegerAttr (#3843)
Browse files Browse the repository at this point in the history
s/AnyIntegerAttr/IntegerAttr/g
  • Loading branch information
alexarice authored Feb 5, 2025
1 parent 842dc07 commit 8ee6ed9
Show file tree
Hide file tree
Showing 30 changed files with 155 additions and 176 deletions.
3 changes: 1 addition & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from xdsl.context import MLContext
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyIntegerAttr,
ArrayAttr,
Builtin,
DictionaryAttr,
Expand Down Expand Up @@ -823,7 +822,7 @@ def test_parse_number(
],
)
def test_parse_optional_builtin_int_or_float_attr(
text: str, expected_value: AnyIntegerAttr | AnyFloatAttr | None
text: str, expected_value: IntegerAttr | AnyFloatAttr | None
):
parser = Parser(MLContext(), text)
if expected_value is None:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from xdsl.dialects import test
from xdsl.dialects.builtin import (
DYNAMIC_INDEX,
AnyIntegerAttr,
AnyTensorTypeConstr,
AnyUnrankedMemRefTypeConstr,
AnyUnrankedTensorTypeConstr,
Expand Down Expand Up @@ -306,7 +305,7 @@ class NoSymNameOp(IRDLOperation):
class SymNameWrongTypeOp(IRDLOperation):
name = "wrong_sym_name_type"

sym_name = attr_def(AnyIntegerAttr)
sym_name = attr_def(IntegerAttr)
traits = traits_def(SymbolOpInterface())

op1 = SymNameWrongTypeOp(
Expand Down
9 changes: 4 additions & 5 deletions xdsl/dialects/accfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import cast

from xdsl.dialects.builtin import (
AnyIntegerAttr,
ArrayAttr,
DictionaryAttr,
IntegerAttr,
Expand Down Expand Up @@ -422,16 +421,16 @@ def verify_(self) -> None:
def field_names(self) -> tuple[str, ...]:
return tuple(self.fields.data.keys())

def field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
def field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.fields.data.items():
yield name, cast(AnyIntegerAttr, val)
yield name, cast(IntegerAttr, val)

def launch_field_names(self) -> tuple[str, ...]:
return tuple(self.launch_fields.data.keys())

def launch_field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
def launch_field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.launch_fields.data.items():
yield name, cast(AnyIntegerAttr, val)
yield name, cast(IntegerAttr, val)


@irdl_op_definition
Expand Down
5 changes: 2 additions & 3 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AffineSetAttr,
AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
Expand Down Expand Up @@ -121,7 +120,7 @@ class ForOp(IRDLOperation):

lowerBoundMap = prop_def(AffineMapAttr)
upperBoundMap = prop_def(AffineMapAttr)
step = prop_def(AnyIntegerAttr)
step = prop_def(IntegerAttr)

body = region_def()

Expand Down Expand Up @@ -168,7 +167,7 @@ def from_region(
lower_bound: int | AffineMapAttr,
upper_bound: int | AffineMapAttr,
region: Region,
step: int | AnyIntegerAttr = 1,
step: int | IntegerAttr = 1,
) -> ForOp:
if isinstance(lower_bound, int):
lower_bound = AffineMapAttr(
Expand Down
45 changes: 22 additions & 23 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
Expand Down Expand Up @@ -145,7 +144,7 @@ class ConstantOp(IRDLOperation):
@overload
def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value: IntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value_type: None = None,
) -> None: ...

Expand All @@ -154,11 +153,11 @@ def __init__(self, value: Attribute, value_type: Attribute) -> None: ...

def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | Attribute,
value: IntegerAttr | FloatAttr[AnyFloat] | Attribute,
value_type: Attribute | None = None,
):
if value_type is None:
value = cast(AnyIntegerAttr | FloatAttr[AnyFloat], value)
value = cast(IntegerAttr | FloatAttr[AnyFloat], value)
value_type = value.type
super().__init__(
operands=[], result_types=[value_type], properties={"value": value}
Expand Down Expand Up @@ -207,7 +206,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return None

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
"""
Returns True only when 'attr' is a right zero for the operation
https://en.wikipedia.org/wiki/Absorbing_element
Expand All @@ -218,7 +217,7 @@ def is_right_zero(attr: AnyIntegerAttr) -> bool:
return False

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
"""
Return True only when 'attr' is a right unit/identity for the operation
https://en.wikipedia.org/wiki/Identity_element
Expand Down Expand Up @@ -379,7 +378,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs + rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -462,11 +461,11 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs * rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -524,7 +523,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs - rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -555,7 +554,7 @@ class DivUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -574,7 +573,7 @@ class DivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -591,7 +590,7 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -604,7 +603,7 @@ class CeilDivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -618,7 +617,7 @@ class CeilDivUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand Down Expand Up @@ -677,7 +676,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs & rhs

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -696,7 +695,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs | rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -715,7 +714,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs ^ rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -733,7 +732,7 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -752,7 +751,7 @@ class ShRUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -772,7 +771,7 @@ class ShRSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -853,7 +852,7 @@ class CmpiOp(ComparisonOperation):
"""

name = "arith.cmpi"
predicate = prop_def(AnyIntegerAttr)
predicate = prop_def(IntegerAttr)
lhs = operand_def(signlessIntegerLike)
rhs = operand_def(signlessIntegerLike)
result = result_def(IntegerType(1))
Expand Down Expand Up @@ -945,7 +944,7 @@ class CmpfOp(ComparisonOperation):
"""

name = "arith.cmpf"
predicate = prop_def(AnyIntegerAttr)
predicate = prop_def(IntegerAttr)
lhs = operand_def(floatingPointLike)
rhs = operand_def(floatingPointLike)
fastmath = prop_def(FastMathFlagsAttr, default_value=FastMathFlagsAttr("none"))
Expand Down
11 changes: 5 additions & 6 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,6 @@ def unpack(
return tuple(IntegerAttr(value, type) for value in type.unpack(buffer, num))


AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]


Expand Down Expand Up @@ -1352,13 +1351,13 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)

def get_attrs(self) -> tuple[AnyIntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
Expand Down Expand Up @@ -2180,7 +2179,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
Expand Down Expand Up @@ -2249,7 +2248,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]:
"""
return self.get_element_type().unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand All @@ -2259,7 +2258,7 @@ def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
else:
return FloatAttr.iter_unpack(eltype, self.data.data)

def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[AnyFloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand Down
3 changes: 1 addition & 2 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
ArrayAttr,
BoolAttr,
ContainerType,
Expand Down Expand Up @@ -418,7 +417,7 @@ def get_element_type(self) -> TypeAttribute:

QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]

ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
ParamAttr: TypeAlias = AnyFloatAttr | IntegerAttr
ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()


Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
FloatAttr,
IndexType,
IntegerAttr,
MemRefType,
TensorType,
)
Expand Down Expand Up @@ -138,15 +138,15 @@ class PrefetchOp(IRDLOperation):

topo = prop_def(dmp.RankTopoAttr)

num_chunks = prop_def(AnyIntegerAttr)
num_chunks = prop_def(IntegerAttr)

result = result_def(MemRefType.constr() | AnyTensorTypeConstr)

def __init__(
self,
input_stencil: SSAValue | Operation,
topo: dmp.RankTopoAttr,
num_chunks: AnyIntegerAttr,
num_chunks: IntegerAttr,
swaps: Sequence[ExchangeDeclarationAttr],
result_type: memref.MemRefType[Attribute] | TensorType[Attribute] | None = None,
):
Expand Down Expand Up @@ -241,7 +241,7 @@ class ApplyOp(IRDLOperation):

topo = prop_def(dmp.RankTopoAttr)

num_chunks = prop_def(AnyIntegerAttr)
num_chunks = prop_def(IntegerAttr)

bounds = opt_prop_def(stencil.StencilBoundsAttr)

Expand Down
Loading

0 comments on commit 8ee6ed9

Please sign in to comment.