diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index efefb47015..01f98a479f 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1031,23 +1031,28 @@ def __init__(self, types: list[Attribute] | ArrayAttr[Attribute]) -> None: super().__init__([types]) +_VectorTypeElement = TypeVar( + "_VectorTypeElement", bound=Attribute, covariant=True, default=Attribute +) + + @irdl_attr_definition class VectorType( - Generic[AttributeCovT], + Generic[_VectorTypeElement], ParametrizedAttribute, TypeAttribute, ShapedType, - ContainerType[AttributeCovT], + ContainerType[_VectorTypeElement], ): name = "vector" shape: ParameterDef[ArrayAttr[IntAttr]] - element_type: ParameterDef[AttributeCovT] + element_type: ParameterDef[_VectorTypeElement] num_scalable_dims: ParameterDef[IntAttr] def __init__( self, - element_type: AttributeCovT, + element_type: _VectorTypeElement, shape: Iterable[int | IntAttr], num_scalable_dims: int | IntAttr = 0, ) -> None: @@ -1067,7 +1072,7 @@ def get_num_scalable_dims(self) -> int: def get_shape(self) -> tuple[int, ...]: return tuple(i.data for i in self.shape) - def get_element_type(self) -> AttributeCovT: + def get_element_type(self) -> _VectorTypeElement: return self.element_type def verify(self): @@ -1086,24 +1091,28 @@ def verify(self): AnyVectorType: TypeAlias = VectorType[Attribute] +_TensorTypeElement = TypeVar( + "_TensorTypeElement", bound=Attribute, covariant=True, default=Attribute +) + @irdl_attr_definition class TensorType( - Generic[AttributeCovT], + Generic[_TensorTypeElement], ParametrizedAttribute, TypeAttribute, ShapedType, - ContainerType[AttributeCovT], + ContainerType[_TensorTypeElement], ): name = "tensor" shape: ParameterDef[ArrayAttr[IntAttr]] - element_type: ParameterDef[AttributeCovT] + element_type: ParameterDef[_TensorTypeElement] encoding: ParameterDef[Attribute] def __init__( self, - element_type: AttributeCovT, + element_type: _TensorTypeElement, shape: Iterable[int | IntAttr], encoding: Attribute = NoneAttr(), ): @@ -1118,7 +1127,7 @@ def get_num_dims(self) -> int: def get_shape(self) -> tuple[int, ...]: return tuple(i.data for i in self.shape.data) - def get_element_type(self) -> AttributeCovT: + def get_element_type(self) -> _TensorTypeElement: return self.element_type @@ -1161,9 +1170,9 @@ class ContainerOf( def __init__( self, - elem_constr: AttributeCovT - | type[AttributeCovT] - | GenericAttrConstraint[AttributeCovT], + elem_constr: ( + AttributeCovT | type[AttributeCovT] | GenericAttrConstraint[AttributeCovT] + ), ) -> None: object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr)) @@ -1822,7 +1831,7 @@ def print(self, printer: Printer) -> None: "_MemRefTypeElement", bound=Attribute, covariant=True, default=Attribute ) _UnrankedMemrefTypeElems = TypeVar( - "_UnrankedMemrefTypeElems", bound=Attribute, covariant=True + "_UnrankedMemrefTypeElems", bound=Attribute, covariant=True, default=Attribute ) _UnrankedMemrefTypeElemsInit = TypeVar("_UnrankedMemrefTypeElemsInit", bound=Attribute) @@ -1985,9 +1994,9 @@ class TensorOrMemrefOf( def __init__( self, - elem_constr: AttributeCovT - | type[AttributeCovT] - | GenericAttrConstraint[AttributeCovT], + elem_constr: ( + AttributeCovT | type[AttributeCovT] | GenericAttrConstraint[AttributeCovT] + ), ) -> None: object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr)) diff --git a/xdsl/dialects/mpi.py b/xdsl/dialects/mpi.py index d06299d5bc..7cadd94d59 100644 --- a/xdsl/dialects/mpi.py +++ b/xdsl/dialects/mpi.py @@ -110,7 +110,7 @@ class DataType(ParametrizedAttribute, TypeAttribute): VectorWrappable = RequestType | StatusType | DataType VectorWrappableConstr = base(RequestType) | base(StatusType) | base(DataType) -_VectorT = TypeVar("_VectorT", bound=VectorWrappable) +_VectorT = TypeVar("_VectorT", bound=VectorWrappable, default=VectorWrappable) @irdl_attr_definition diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 77f1a3b61d..bb25a655d9 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -370,10 +370,15 @@ def constr( return ParamAttrConstraint(cls, (bounds, element_type)) +_StencilTypeElement = TypeVar( + "_StencilTypeElement", bound=Attribute, covariant=True, default=Attribute +) + + @irdl_attr_definition class FieldType( - Generic[_FieldTypeElement], - StencilType[_FieldTypeElement], + Generic[_StencilTypeElement], + StencilType[_StencilTypeElement], ParametrizedAttribute, TypeAttribute, ): @@ -387,10 +392,15 @@ class FieldType( name = "stencil.field" +_TempTypeElement = TypeVar( + "_TempTypeElement", bound=Attribute, covariant=True, default=Attribute +) + + @irdl_attr_definition class TempType( - Generic[_FieldTypeElement], - StencilType[_FieldTypeElement], + Generic[_TempTypeElement], + StencilType[_TempTypeElement], ParametrizedAttribute, TypeAttribute, ):