diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 2872d469b0..0f8caaedc0 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -10,9 +10,9 @@ from collections.abc import Iterator, Sequence from enum import auto from itertools import product -from typing import Any, ClassVar, Generic, TypeVar, cast +from typing import Any, ClassVar, Generic, cast -from typing_extensions import Self +from typing_extensions import Self, TypeVar from xdsl.dialects import memref from xdsl.dialects.builtin import ( @@ -39,7 +39,6 @@ from xdsl.irdl import ( AnyAttr, AttrSizedOperandSegments, - BaseAttr, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, @@ -66,8 +65,9 @@ from xdsl.utils.hints import isa from xdsl.utils.str_enum import StrEnum -_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) -_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) +_StreamTypeElement = TypeVar( + "_StreamTypeElement", bound=Attribute, covariant=True, default=Attribute +) @irdl_attr_definition @@ -87,20 +87,16 @@ def get_element_type(self) -> _StreamTypeElement: def __init__(self, element_type: _StreamTypeElement): super().__init__([element_type]) - @staticmethod + @classmethod def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( + cls, + element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(), + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]: + return ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]( ReadableStreamType, (element_type,) ) -AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( - ReadableStreamType -) - - @irdl_attr_definition class WritableStreamType( Generic[_StreamTypeElement], @@ -118,20 +114,16 @@ def get_element_type(self) -> _StreamTypeElement: def __init__(self, element_type: _StreamTypeElement): super().__init__([element_type]) - @staticmethod + @classmethod def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( + cls, + element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(), + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]: + return ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]( WritableStreamType, (element_type,) ) -AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( - WritableStreamType -) - - class IteratorType(StrEnum): "Iterator type for memref_stream Attribute" @@ -471,7 +463,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | AnyWritableStreamTypeConstr) + outputs = var_operand_def(AnyMemRefTypeConstr | WritableStreamType.constr()) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/snitch.py b/xdsl/dialects/snitch.py index 37309e0f99..5a89c793a7 100644 --- a/xdsl/dialects/snitch.py +++ b/xdsl/dialects/snitch.py @@ -13,7 +13,9 @@ from abc import ABC from collections.abc import Sequence from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Generic + +from typing_extensions import TypeVar from xdsl.dialects.builtin import ContainerType, IntAttr from xdsl.dialects.riscv import IntRegisterType @@ -26,7 +28,7 @@ TypeAttribute, ) from xdsl.irdl import ( - BaseAttr, + AnyAttr, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, @@ -39,8 +41,9 @@ ) from xdsl.utils.exceptions import VerifyException -_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) -_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) +_StreamTypeElement = TypeVar( + "_StreamTypeElement", bound=Attribute, covariant=True, default=Attribute +) @irdl_attr_definition @@ -60,20 +63,16 @@ def get_element_type(self) -> _StreamTypeElement: def __init__(self, element_type: _StreamTypeElement): super().__init__([element_type]) - @staticmethod + @classmethod def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( + cls, + element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(), + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]: + return ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]( ReadableStreamType, (element_type,) ) -AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( - ReadableStreamType -) - - @irdl_attr_definition class WritableStreamType( Generic[_StreamTypeElement], @@ -91,20 +90,16 @@ def get_element_type(self) -> _StreamTypeElement: def __init__(self, element_type: _StreamTypeElement): super().__init__([element_type]) - @staticmethod + @classmethod def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( + cls, + element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(), + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]: + return ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]( WritableStreamType, (element_type,) ) -AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( - WritableStreamType -) - - @dataclass(frozen=True) class SnitchResources: """ @@ -222,7 +217,7 @@ class SsrEnableOp(IRDLOperation): name = "snitch.ssr_enable" - streams = var_result_def(AnyReadableStreamTypeConstr | AnyWritableStreamTypeConstr) + streams = var_result_def(ReadableStreamType.constr() | WritableStreamType.constr()) def __init__(self, stream_types: Sequence[Attribute]): super().__init__(result_types=[stream_types])