Skip to content

Commit

Permalink
dialects: (builtin) remove AnyFloatAttr(Constr)? (#3844)
Browse files Browse the repository at this point in the history
Adds a default type to `FloatAttr` and removes `AnyFloatAttr` and
`AnyFloatAttrConstr`
  • Loading branch information
alexarice authored and emmau678 committed Feb 6, 2025
1 parent 51db69f commit 4504a60
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 57 deletions.
26 changes: 13 additions & 13 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from xdsl.context import MLContext
from xdsl.dialects.builtin import (
AnyFloatAttr,
ArrayAttr,
Builtin,
DictionaryAttr,
FloatAttr,
IntAttr,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -807,22 +807,22 @@ def test_parse_number(
("24: i32", IntegerAttr(24, 32)),
("0: index", IntegerAttr.from_index_int_value(0)),
("-64: i64", IntegerAttr(-64, 64)),
("-64.4: f64", AnyFloatAttr(-64.4, 64)),
("32.4: f32", AnyFloatAttr(32.4, 32)),
("0x7e00 : f16", AnyFloatAttr(float("nan"), 16)),
("0x7c00 : f16", AnyFloatAttr(float("inf"), 16)),
("0xfc00 : f16", AnyFloatAttr(float("-inf"), 16)),
("0x7fc00000 : f32", AnyFloatAttr(float("nan"), 32)),
("0x7f800000 : f32", AnyFloatAttr(float("inf"), 32)),
("0xff800000 : f32", AnyFloatAttr(float("-inf"), 32)),
("0x7ff8000000000000 : f64", AnyFloatAttr(float("nan"), 64)),
("0x7ff0000000000000 : f64", AnyFloatAttr(float("inf"), 64)),
("0xfff0000000000000 : f64", AnyFloatAttr(float("-inf"), 64)),
("-64.4: f64", FloatAttr(-64.4, 64)),
("32.4: f32", FloatAttr(32.4, 32)),
("0x7e00 : f16", FloatAttr(float("nan"), 16)),
("0x7c00 : f16", FloatAttr(float("inf"), 16)),
("0xfc00 : f16", FloatAttr(float("-inf"), 16)),
("0x7fc00000 : f32", FloatAttr(float("nan"), 32)),
("0x7f800000 : f32", FloatAttr(float("inf"), 32)),
("0xff800000 : f32", FloatAttr(float("-inf"), 32)),
("0x7ff8000000000000 : f64", FloatAttr(float("nan"), 64)),
("0x7ff0000000000000 : f64", FloatAttr(float("inf"), 64)),
("0xfff0000000000000 : f64", FloatAttr(float("-inf"), 64)),
# ("3 : f64", None), # todo this fails in mlir-opt but not in xdsl
],
)
def test_parse_optional_builtin_int_or_float_attr(
text: str, expected_value: IntegerAttr | AnyFloatAttr | None
text: str, expected_value: IntegerAttr | FloatAttr | None
):
parser = Parser(MLContext(), text)
if expected_value is None:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp
from xdsl.dialects.builtin import (
AnyFloat,
AnyFloatAttr,
Builtin,
FloatAttr,
FunctionType,
Expand Down Expand Up @@ -814,7 +813,7 @@ def _test_float_attr(value: float, type: AnyFloat):
def test_float_attr_specials():
printer = Printer()

def _test_attr_print(expected: str, attr: AnyFloatAttr):
def _test_attr_print(expected: str, attr: FloatAttr):
io = StringIO()
printer.stream = io
printer.print_attribute(attr)
Expand Down
28 changes: 13 additions & 15 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,9 @@ def __hash__(self):
return hash(self.data)


_FloatAttrType = TypeVar("_FloatAttrType", bound=AnyFloat, covariant=True)
_FloatAttrType = TypeVar(
"_FloatAttrType", bound=AnyFloat, covariant=True, default=AnyFloat
)
_FloatAttrTypeInvT = TypeVar("_FloatAttrTypeInvT", bound=AnyFloat)


Expand Down Expand Up @@ -980,10 +982,6 @@ def unpack(
return tuple(FloatAttr(value, type) for value in type.unpack(buffer, num))


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)


@irdl_attr_definition
class ComplexType(ParametrizedAttribute, TypeAttribute):
name = "complex"
Expand Down Expand Up @@ -1351,13 +1349,13 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[FloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)

def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[FloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
Expand Down Expand Up @@ -2134,10 +2132,10 @@ def create_dense_int(
@staticmethod
def create_dense_float(
type: RankedStructure[AnyFloat],
data: Sequence[int | float] | Sequence[AnyFloatAttr],
data: Sequence[int | float] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr:
if len(data) and isa(data[0], AnyFloatAttr):
data = [el.value.data for el in cast(Sequence[AnyFloatAttr], data)]
if len(data) and isa(data[0], FloatAttr):
data = [el.value.data for el in cast(Sequence[FloatAttr], data)]
else:
data = cast(Sequence[float], data)

Expand Down Expand Up @@ -2168,7 +2166,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[AnyFloatAttr],
data: Sequence[int | float] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr: ...

@staticmethod
Expand All @@ -2179,7 +2177,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[AnyFloatAttr],
data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
Expand Down Expand Up @@ -2228,7 +2226,7 @@ def tensor_from_list(
| Sequence[float]
| Sequence[IntegerAttr[IndexType]]
| Sequence[IntegerAttr[IntegerType]]
| Sequence[AnyFloatAttr]
| Sequence[FloatAttr]
),
data_type: IntegerType | IndexType | AnyFloat,
shape: Sequence[int],
Expand All @@ -2248,7 +2246,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]:
"""
return self.get_element_type().unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[FloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand All @@ -2258,7 +2256,7 @@ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
else:
return FloatAttr.iter_unpack(eltype, self.data.data)

def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[AnyFloatAttr]:
def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[FloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
Expand Down
7 changes: 3 additions & 4 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
ArrayAttr,
BoolAttr,
ContainerType,
DictionaryAttr,
Float16Type,
Float32Type,
FloatAttr,
FunctionType,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -417,8 +416,8 @@ def get_element_type(self) -> TypeAttribute:

QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]

ParamAttr: TypeAlias = AnyFloatAttr | IntegerAttr
ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()
ParamAttr: TypeAlias = FloatAttr | IntegerAttr
ParamAttrConstr = FloatAttr.constr() | IntegerAttr.constr()


@irdl_op_definition
Expand Down
4 changes: 2 additions & 2 deletions xdsl/interpreters/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import cast

from xdsl.dialects import arith
from xdsl.dialects.builtin import AnyFloatAttr, IntegerAttr
from xdsl.dialects.builtin import FloatAttr, IntegerAttr
from xdsl.interpreter import (
Interpreter,
InterpreterFunctions,
Expand All @@ -23,7 +23,7 @@ def run_constant(
) -> PythonValues:
value = op.value
interpreter.interpreter_assert(
isattr(op.value, base(IntegerAttr) | base(AnyFloatAttr)),
isattr(op.value, base(IntegerAttr) | base(FloatAttr)),
f"arith.constant not implemented for {type(op.value)}",
)
value = cast(IntegerAttr, op.value)
Expand Down
10 changes: 5 additions & 5 deletions xdsl/interpreters/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AnyFloatAttr,
Float32Type,
Float64Type,
FloatAttr,
IntegerAttr,
IntegerType,
PackableType,
Expand Down Expand Up @@ -59,16 +59,16 @@ def run_cast(
def float64_attr_value(
self, interpreter: Interpreter, attr: Attribute, attr_type: Float64Type
) -> float:
interpreter.interpreter_assert(isa(attr, AnyFloatAttr))
attr = cast(AnyFloatAttr, attr)
interpreter.interpreter_assert(isa(attr, FloatAttr))
attr = cast(FloatAttr, attr)
return attr.value.data

@impl_attr(Float32Type)
def float32_attr_value(
self, interpreter: Interpreter, attr: Attribute, attr_type: Float32Type
) -> float:
interpreter.interpreter_assert(isa(attr, AnyFloatAttr))
attr = cast(AnyFloatAttr, attr)
interpreter.interpreter_assert(isa(attr, FloatAttr))
attr = cast(FloatAttr, attr)
return attr.value.data

@impl_attr(IntegerType)
Expand Down
3 changes: 1 addition & 2 deletions xdsl/parser/attribute_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
AnyArrayAttr,
AnyDenseElement,
AnyFloat,
AnyFloatAttr,
AnyFloatConstr,
AnyTensorType,
AnyUnrankedTensorType,
Expand Down Expand Up @@ -1202,7 +1201,7 @@ def parse_optional_location(self) -> LocationAttr | None:

def parse_optional_builtin_int_or_float_attr(
self,
) -> IntegerAttr | AnyFloatAttr | None:
) -> IntegerAttr | FloatAttr | None:
bool = self.try_parse_builtin_boolean_attr()
if bool is not None:
return bool
Expand Down
4 changes: 2 additions & 2 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
AffineMapAttr,
AffineSetAttr,
AnyFloat,
AnyFloatAttr,
AnyUnrankedMemRefType,
AnyUnrankedTensorType,
AnyVectorType,
Expand All @@ -28,6 +27,7 @@
Float64Type,
Float80Type,
Float128Type,
FloatAttr,
FunctionType,
IndexType,
IntAttr,
Expand Down Expand Up @@ -333,7 +333,7 @@ def print_bytes_literal(self, bytestring: bytes):
self.print_string(chr(byte))
self.print_string('"')

def print_float_attr(self, attribute: AnyFloatAttr):
def print_float_attr(self, attribute: FloatAttr):
self.print_float(attribute.value.data, attribute.type)

def print_float(self, value: float, type: AnyFloat):
Expand Down
12 changes: 6 additions & 6 deletions xdsl/transforms/canonicalization_patterns/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def match_and_rewrite(

def _fold_const_operation(
op_t: type[arith.FloatingPointLikeBinaryOperation],
lhs: builtin.AnyFloatAttr,
rhs: builtin.AnyFloatAttr,
lhs: builtin.FloatAttr,
rhs: builtin.FloatAttr,
) -> arith.ConstantOp | None:
match op_t:
case arith.AddfOp:
Expand Down Expand Up @@ -88,8 +88,8 @@ def match_and_rewrite(
if (
isinstance(op.lhs.owner, arith.ConstantOp)
and isinstance(op.rhs.owner, arith.ConstantOp)
and isa(l := op.lhs.owner.value, builtin.AnyFloatAttr)
and isa(r := op.rhs.owner.value, builtin.AnyFloatAttr)
and isa(l := op.lhs.owner.value, builtin.FloatAttr)
and isa(r := op.rhs.owner.value, builtin.FloatAttr)
and (cnst := _fold_const_operation(type(op), l, r))
):
rewriter.replace_matched_op(cnst)
Expand Down Expand Up @@ -126,8 +126,8 @@ def match_and_rewrite(
or u.fastmath is None
or arith.FastMathFlag.REASSOC not in op.fastmath.flags
or arith.FastMathFlag.REASSOC not in u.fastmath.flags
or not isa(c1 := const1.value, builtin.AnyFloatAttr)
or not isa(c2 := const2.value, builtin.AnyFloatAttr)
or not isa(c1 := const1.value, builtin.FloatAttr)
or not isa(c2 := const2.value, builtin.FloatAttr)
):
return

Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref, stencil, tensor, varith
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyTensorType,
DenseIntOrFPElementsAttr,
FloatAttr,
IndexType,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -578,7 +578,7 @@ def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter,
return

val = dense.get_attrs()[0]
assert isattr(val, AnyFloatAttr)
assert isattr(val, FloatAttr)
apply.add_coeff(op.offset, val)
rewriter.replace_op(mulf, [], new_results=[op.result])

Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/linalg_to_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, linalg
from xdsl.dialects.builtin import (
AnyFloatAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
FloatAttr,
IntegerAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -39,7 +39,7 @@ def match_op_for_precision(
raise ValueError(f"Unsupported element type {prec}")


def get_scalar_const(op: SSAValue) -> AnyFloatAttr | IntegerAttr | None:
def get_scalar_const(op: SSAValue) -> FloatAttr | IntegerAttr | None:
"""Returns the value of a scalar arith.constant, or None if not a constant or not scalar)."""
if (
isinstance(op, OpResult)
Expand Down
3 changes: 1 addition & 2 deletions xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xdsl.dialects import arith, func, memref, stencil
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
Expand Down Expand Up @@ -265,7 +264,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
pattern = wrapper.get_param_value("pattern").value.data
neighbours = pattern - 1
empty = [FloatAttr(f, elem_t) for f in [0] + neighbours * [1]]
cmap: dict[csl.Direction, list[AnyFloatAttr]] = {
cmap: dict[csl.Direction, list[FloatAttr]] = {
csl.Direction.NORTH: empty,
csl.Direction.SOUTH: empty.copy(),
csl.Direction.EAST: empty.copy(),
Expand Down

0 comments on commit 4504a60

Please sign in to comment.