Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (builtin) remove AnyIntegerAttr #3843

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from xdsl.context import MLContext
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyIntegerAttr,
ArrayAttr,
Builtin,
DictionaryAttr,
Expand Down Expand Up @@ -823,7 +822,7 @@ def test_parse_number(
],
)
def test_parse_optional_builtin_int_or_float_attr(
text: str, expected_value: AnyIntegerAttr | AnyFloatAttr | None
text: str, expected_value: IntegerAttr | AnyFloatAttr | None
):
parser = Parser(MLContext(), text)
if expected_value is None:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from xdsl.dialects import test
from xdsl.dialects.builtin import (
DYNAMIC_INDEX,
AnyIntegerAttr,
AnyTensorTypeConstr,
AnyUnrankedMemRefTypeConstr,
AnyUnrankedTensorTypeConstr,
Expand Down Expand Up @@ -306,7 +305,7 @@ class NoSymNameOp(IRDLOperation):
class SymNameWrongTypeOp(IRDLOperation):
name = "wrong_sym_name_type"

sym_name = attr_def(AnyIntegerAttr)
sym_name = attr_def(IntegerAttr)
traits = traits_def(SymbolOpInterface())

op1 = SymNameWrongTypeOp(
Expand Down
9 changes: 4 additions & 5 deletions xdsl/dialects/accfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import cast

from xdsl.dialects.builtin import (
AnyIntegerAttr,
ArrayAttr,
DictionaryAttr,
IntegerAttr,
Expand Down Expand Up @@ -422,16 +421,16 @@ def verify_(self) -> None:
def field_names(self) -> tuple[str, ...]:
return tuple(self.fields.data.keys())

def field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
def field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.fields.data.items():
yield name, cast(AnyIntegerAttr, val)
yield name, cast(IntegerAttr, val)

def launch_field_names(self) -> tuple[str, ...]:
return tuple(self.launch_fields.data.keys())

def launch_field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
def launch_field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.launch_fields.data.items():
yield name, cast(AnyIntegerAttr, val)
yield name, cast(IntegerAttr, val)


@irdl_op_definition
Expand Down
5 changes: 2 additions & 3 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AffineSetAttr,
AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
Expand Down Expand Up @@ -121,7 +120,7 @@ class ForOp(IRDLOperation):

lowerBoundMap = prop_def(AffineMapAttr)
upperBoundMap = prop_def(AffineMapAttr)
step = prop_def(AnyIntegerAttr)
step = prop_def(IntegerAttr)

body = region_def()

Expand Down Expand Up @@ -168,7 +167,7 @@ def from_region(
lower_bound: int | AffineMapAttr,
upper_bound: int | AffineMapAttr,
region: Region,
step: int | AnyIntegerAttr = 1,
step: int | IntegerAttr = 1,
) -> ForOp:
if isinstance(lower_bound, int):
lower_bound = AffineMapAttr(
Expand Down
45 changes: 22 additions & 23 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
Expand Down Expand Up @@ -145,7 +144,7 @@ class ConstantOp(IRDLOperation):
@overload
def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value: IntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value_type: None = None,
) -> None: ...

Expand All @@ -154,11 +153,11 @@ def __init__(self, value: Attribute, value_type: Attribute) -> None: ...

def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | Attribute,
value: IntegerAttr | FloatAttr[AnyFloat] | Attribute,
value_type: Attribute | None = None,
):
if value_type is None:
value = cast(AnyIntegerAttr | FloatAttr[AnyFloat], value)
value = cast(IntegerAttr | FloatAttr[AnyFloat], value)
value_type = value.type
super().__init__(
operands=[], result_types=[value_type], properties={"value": value}
Expand Down Expand Up @@ -207,7 +206,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return None

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
"""
Returns True only when 'attr' is a right zero for the operation
https://en.wikipedia.org/wiki/Absorbing_element
Expand All @@ -218,7 +217,7 @@ def is_right_zero(attr: AnyIntegerAttr) -> bool:
return False

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
"""
Return True only when 'attr' is a right unit/identity for the operation
https://en.wikipedia.org/wiki/Identity_element
Expand Down Expand Up @@ -379,7 +378,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs + rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -462,11 +461,11 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs * rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -524,7 +523,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs - rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -555,7 +554,7 @@ class DivUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -574,7 +573,7 @@ class DivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -591,7 +590,7 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -604,7 +603,7 @@ class CeilDivSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand All @@ -618,7 +617,7 @@ class CeilDivUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


Expand Down Expand Up @@ -677,7 +676,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs & rhs

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -696,7 +695,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs | rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -715,7 +714,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs ^ rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -733,7 +732,7 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -752,7 +751,7 @@ class ShRUIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand All @@ -772,7 +771,7 @@ class ShRSIOp(SignlessIntegerBinaryOperation):
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0


Expand Down Expand Up @@ -853,7 +852,7 @@ class CmpiOp(ComparisonOperation):
"""

name = "arith.cmpi"
predicate = prop_def(AnyIntegerAttr)
predicate = prop_def(IntegerAttr)
lhs = operand_def(signlessIntegerLike)
rhs = operand_def(signlessIntegerLike)
result = result_def(IntegerType(1))
Expand Down Expand Up @@ -945,7 +944,7 @@ class CmpfOp(ComparisonOperation):
"""

name = "arith.cmpf"
predicate = prop_def(AnyIntegerAttr)
predicate = prop_def(IntegerAttr)
lhs = operand_def(floatingPointLike)
rhs = operand_def(floatingPointLike)
fastmath = prop_def(FastMathFlagsAttr, default_value=FastMathFlagsAttr("none"))
Expand Down
11 changes: 5 additions & 6 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,6 @@ def unpack(
return tuple(IntegerAttr(value, type) for value in type.unpack(buffer, num))


AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]


Expand Down Expand Up @@ -1352,13 +1351,13 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)

def get_attrs(self) -> tuple[AnyIntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
Expand Down Expand Up @@ -2180,7 +2179,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
Expand Down Expand Up @@ -2249,7 +2248,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]:
"""
return self.get_element_type().unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand All @@ -2259,7 +2258,7 @@ def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
else:
return FloatAttr.iter_unpack(eltype, self.data.data)

def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[AnyFloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand Down
3 changes: 1 addition & 2 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
ArrayAttr,
BoolAttr,
ContainerType,
Expand Down Expand Up @@ -418,7 +417,7 @@ def get_element_type(self) -> TypeAttribute:

QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]

ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
ParamAttr: TypeAlias = AnyFloatAttr | IntegerAttr
ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()


Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
FloatAttr,
IndexType,
IntegerAttr,
MemRefType,
TensorType,
)
Expand Down Expand Up @@ -138,15 +138,15 @@ class PrefetchOp(IRDLOperation):

topo = prop_def(dmp.RankTopoAttr)

num_chunks = prop_def(AnyIntegerAttr)
num_chunks = prop_def(IntegerAttr)

result = result_def(MemRefType.constr() | AnyTensorTypeConstr)

def __init__(
self,
input_stencil: SSAValue | Operation,
topo: dmp.RankTopoAttr,
num_chunks: AnyIntegerAttr,
num_chunks: IntegerAttr,
swaps: Sequence[ExchangeDeclarationAttr],
result_type: memref.MemRefType[Attribute] | TensorType[Attribute] | None = None,
):
Expand Down Expand Up @@ -241,7 +241,7 @@ class ApplyOp(IRDLOperation):

topo = prop_def(dmp.RankTopoAttr)

num_chunks = prop_def(AnyIntegerAttr)
num_chunks = prop_def(IntegerAttr)

bounds = opt_prop_def(stencil.StencilBoundsAttr)

Expand Down
Loading