Skip to content

Commit

Permalink
dialects: Add more defaults to TypeVar
Browse files Browse the repository at this point in the history
stack-info: PR: #3835, branch: math-fehr/stack/1
  • Loading branch information
math-fehr committed Feb 5, 2025
1 parent c503d9d commit 2c7e966
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
43 changes: 26 additions & 17 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,23 +1031,28 @@ def __init__(self, types: list[Attribute] | ArrayAttr[Attribute]) -> None:
super().__init__([types])


_VectorTypeElement = TypeVar(
"_VectorTypeElement", bound=Attribute, covariant=True, default=Attribute
)


@irdl_attr_definition
class VectorType(
Generic[AttributeCovT],
Generic[_VectorTypeElement],
ParametrizedAttribute,
TypeAttribute,
ShapedType,
ContainerType[AttributeCovT],
ContainerType[_VectorTypeElement],
):
name = "vector"

shape: ParameterDef[ArrayAttr[IntAttr]]
element_type: ParameterDef[AttributeCovT]
element_type: ParameterDef[_VectorTypeElement]
num_scalable_dims: ParameterDef[IntAttr]

def __init__(
self,
element_type: AttributeCovT,
element_type: _VectorTypeElement,
shape: Iterable[int | IntAttr],
num_scalable_dims: int | IntAttr = 0,
) -> None:
Expand All @@ -1067,7 +1072,7 @@ def get_num_scalable_dims(self) -> int:
def get_shape(self) -> tuple[int, ...]:
return tuple(i.data for i in self.shape)

def get_element_type(self) -> AttributeCovT:
def get_element_type(self) -> _VectorTypeElement:
return self.element_type

def verify(self):
Expand All @@ -1086,24 +1091,28 @@ def verify(self):

AnyVectorType: TypeAlias = VectorType[Attribute]

_TensorTypeElement = TypeVar(
"_TensorTypeElement", bound=Attribute, covariant=True, default=Attribute
)


@irdl_attr_definition
class TensorType(
Generic[AttributeCovT],
Generic[_TensorTypeElement],
ParametrizedAttribute,
TypeAttribute,
ShapedType,
ContainerType[AttributeCovT],
ContainerType[_TensorTypeElement],
):
name = "tensor"

shape: ParameterDef[ArrayAttr[IntAttr]]
element_type: ParameterDef[AttributeCovT]
element_type: ParameterDef[_TensorTypeElement]
encoding: ParameterDef[Attribute]

def __init__(
self,
element_type: AttributeCovT,
element_type: _TensorTypeElement,
shape: Iterable[int | IntAttr],
encoding: Attribute = NoneAttr(),
):
Expand All @@ -1118,7 +1127,7 @@ def get_num_dims(self) -> int:
def get_shape(self) -> tuple[int, ...]:
return tuple(i.data for i in self.shape.data)

def get_element_type(self) -> AttributeCovT:
def get_element_type(self) -> _TensorTypeElement:
return self.element_type


Expand Down Expand Up @@ -1161,9 +1170,9 @@ class ContainerOf(

def __init__(
self,
elem_constr: AttributeCovT
| type[AttributeCovT]
| GenericAttrConstraint[AttributeCovT],
elem_constr: (
AttributeCovT | type[AttributeCovT] | GenericAttrConstraint[AttributeCovT]
),
) -> None:
object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr))

Expand Down Expand Up @@ -1822,7 +1831,7 @@ def print(self, printer: Printer) -> None:
"_MemRefTypeElement", bound=Attribute, covariant=True, default=Attribute
)
_UnrankedMemrefTypeElems = TypeVar(
"_UnrankedMemrefTypeElems", bound=Attribute, covariant=True
"_UnrankedMemrefTypeElems", bound=Attribute, covariant=True, default=Attribute
)
_UnrankedMemrefTypeElemsInit = TypeVar("_UnrankedMemrefTypeElemsInit", bound=Attribute)

Expand Down Expand Up @@ -1985,9 +1994,9 @@ class TensorOrMemrefOf(

def __init__(
self,
elem_constr: AttributeCovT
| type[AttributeCovT]
| GenericAttrConstraint[AttributeCovT],
elem_constr: (
AttributeCovT | type[AttributeCovT] | GenericAttrConstraint[AttributeCovT]
),
) -> None:
object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr))

Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class DataType(ParametrizedAttribute, TypeAttribute):

VectorWrappable = RequestType | StatusType | DataType
VectorWrappableConstr = base(RequestType) | base(StatusType) | base(DataType)
_VectorT = TypeVar("_VectorT", bound=VectorWrappable)
_VectorT = TypeVar("_VectorT", bound=VectorWrappable, default=VectorWrappable)

Check failure on line 113 in xdsl/dialects/mpi.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 113 in xdsl/dialects/mpi.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 113 in xdsl/dialects/mpi.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)


@irdl_attr_definition
Expand Down
18 changes: 14 additions & 4 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,15 @@ def constr(
return ParamAttrConstraint(cls, (bounds, element_type))


_StencilTypeElement = TypeVar(
"_StencilTypeElement", bound=Attribute, covariant=True, default=Attribute

Check failure on line 374 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 374 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 374 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)
)


@irdl_attr_definition
class FieldType(
Generic[_FieldTypeElement],
StencilType[_FieldTypeElement],
Generic[_StencilTypeElement],
StencilType[_StencilTypeElement],
ParametrizedAttribute,
TypeAttribute,
):
Expand All @@ -387,10 +392,15 @@ class FieldType(
name = "stencil.field"


_TempTypeElement = TypeVar(
"_TempTypeElement", bound=Attribute, covariant=True, default=Attribute

Check failure on line 396 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 396 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)

Check failure on line 396 in xdsl/dialects/stencil.py

View workflow job for this annotation

GitHub Actions / build (3.12)

Type variable default types require Python 3.13 or newer (reportGeneralTypeIssues)
)


@irdl_attr_definition
class TempType(
Generic[_FieldTypeElement],
StencilType[_FieldTypeElement],
Generic[_TempTypeElement],
StencilType[_TempTypeElement],
ParametrizedAttribute,
TypeAttribute,
):
Expand Down

0 comments on commit 2c7e966

Please sign in to comment.