Skip to content

Commit

Permalink
core: Add shorthand type for types that can be converted to IRDL
Browse files Browse the repository at this point in the history
There was a lot of inconsistencies in the typing of functions that are
using the type to constraint conversion. Having a shorthand makes this
more explicit.

stack-info: PR: #3836, branch: math-fehr/stack/2
  • Loading branch information
math-fehr committed Feb 6, 2025
1 parent 2a280b2 commit c6cc4a9
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 94 deletions.
32 changes: 30 additions & 2 deletions xdsl/irdl/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from inspect import isclass
from types import FunctionType, GenericAlias, UnionType
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Generic,
TypeAlias,
TypeVar,
Union,
cast,
Expand All @@ -22,6 +24,10 @@
get_type_hints,
)

if TYPE_CHECKING:
from typing_extensions import TypeForm


from xdsl.ir import (
Attribute,
AttributeInvT,
Expand Down Expand Up @@ -256,8 +262,30 @@ def irdl_attr_definition(cls: TypeAttributeInvT) -> TypeAttributeInvT:
)


IRDLGenericAttributeConstraint: TypeAlias = (
GenericAttrConstraint[AttributeInvT]
| Attribute
| type[AttributeInvT]
| "TypeForm[AttributeInvT]"
| ConstraintVar
| TypeVar
)
"""
Attribute constraints represented using the IRDL python frontend. Attribute constraints
can either be:
- An instance of `AttrConstraint` representing a constraint on an attribute.
- An instance of `Attribute` representing an equality constraint on an attribute.
- A type representing a specific attribute class.
- A TypeForm that can represent both unions and generic attributes.
- A `ConstraintVar` representing a constraint variable.
"""

IRDLAttributeConstraint = IRDLGenericAttributeConstraint[Attribute]
"""See `IRDLGenericAttributeConstraint`."""


def irdl_list_to_attr_constraint(
pyrdl_constraints: Sequence[Any],
pyrdl_constraints: Sequence[IRDLAttributeConstraint],
*,
allow_type_var: bool = False,
type_var_mapping: dict[TypeVar, AttrConstraint] | None = None,
Expand Down Expand Up @@ -299,7 +327,7 @@ def irdl_list_to_attr_constraint(


def irdl_to_attr_constraint(
irdl: Any,
irdl: IRDLAttributeConstraint,
*,
allow_type_var: bool = False,
type_var_mapping: dict[TypeVar, AttrConstraint] | None = None,
Expand Down
127 changes: 36 additions & 91 deletions xdsl/irdl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
overload,
)

import typing_extensions
from typing_extensions import assert_never

from xdsl.ir import (
Expand All @@ -46,6 +45,8 @@
)

from .attributes import ( # noqa: TID251
IRDLAttributeConstraint,
IRDLGenericAttributeConstraint,
irdl_list_to_attr_constraint,
irdl_to_attr_constraint,
)
Expand All @@ -54,7 +55,6 @@
AttrConstraint,
ConstraintContext,
ConstraintVar,
GenericAttrConstraint,
GenericRangeConstraint,
RangeConstraint,
RangeOf,
Expand Down Expand Up @@ -521,25 +521,19 @@ def __init__(self, cls: type[_ClsT]):


class _RangeConstrainedOpDefField(Generic[_ClsT], _OpDefField[_ClsT]):
param: RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar
param: RangeConstraint | IRDLAttributeConstraint

def __init__(
self,
cls: type[_ClsT],
param: RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar,
self, cls: type[_ClsT], param: RangeConstraint | IRDLAttributeConstraint
):
super().__init__(cls)
self.param = param


class _ConstrainedOpDefField(Generic[_ClsT], _OpDefField[_ClsT]):
param: AttrConstraint | Attribute | type[Attribute] | TypeVar
param: IRDLAttributeConstraint

def __init__(
self,
cls: type[_ClsT],
param: AttrConstraint | Attribute | type[Attribute] | TypeVar,
):
def __init__(self, cls: type[_ClsT], param: IRDLAttributeConstraint):
super().__init__(cls)
self.param = param

Expand Down Expand Up @@ -568,7 +562,7 @@ class _AttrOrPropFieldDef(
def __init__(
self,
cls: type[AttrOrPropInvT],
param: AttrConstraint | Attribute | type[Attribute] | TypeVar,
param: IRDLAttributeConstraint,
ir_name: str | None = None,
default_value: Attribute | None = None,
):
Expand All @@ -586,24 +580,12 @@ class _PropertyFieldDef(_AttrOrPropFieldDef[PropertyDef]):


class _RegionFieldDef(_OpDefField[RegionDef]):
entry_args: (
GenericRangeConstraint[Attribute]
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
)
entry_args: GenericRangeConstraint[Attribute] | IRDLAttributeConstraint

def __init__(
self,
cls: type[RegionDef],
entry_args: (
GenericRangeConstraint[Attribute]
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
),
entry_args: GenericRangeConstraint[Attribute] | IRDLAttributeConstraint,
):
super().__init__(cls)
self.entry_args = entry_args
Expand All @@ -616,7 +598,7 @@ class _SuccessorFieldDef(_OpDefField[SuccessorDef]):


def result_def(
constraint: (AttrConstraint | Attribute | type[Attribute] | TypeVar) = Attribute,
constraint: IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -629,9 +611,7 @@ def result_def(


def var_result_def(
constraint: (
RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar
) = Attribute,
constraint: RangeConstraint | IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -644,9 +624,7 @@ def var_result_def(


def opt_result_def(
constraint: (
RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar
) = Attribute,
constraint: RangeConstraint | IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -659,7 +637,7 @@ def opt_result_def(


def prop_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute | None = None,
*,
prop_name: str | None = None,
Expand All @@ -676,7 +654,7 @@ def prop_def(

@overload
def opt_prop_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: None = None,
*,
prop_name: str | None = None,
Expand All @@ -688,7 +666,7 @@ def opt_prop_def(

@overload
def opt_prop_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute,
*,
prop_name: str | None = None,
Expand All @@ -699,7 +677,7 @@ def opt_prop_def(


def opt_prop_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute | None = None,
*,
prop_name: str | None = None,
Expand All @@ -715,7 +693,7 @@ def opt_prop_def(


def attr_def(
constraint: (type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT]),
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute | None = None,
*,
attr_name: str | None = None,
Expand All @@ -734,7 +712,7 @@ def attr_def(

@overload
def opt_attr_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: None = None,
*,
attr_name: str | None = None,
Expand All @@ -746,7 +724,7 @@ def opt_attr_def(

@overload
def opt_attr_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute,
*,
attr_name: str | None = None,
Expand All @@ -757,7 +735,7 @@ def opt_attr_def(


def opt_attr_def(
constraint: type[AttributeInvT] | TypeVar | AttrConstraint,
constraint: IRDLGenericAttributeConstraint[AttributeInvT],
default_value: Attribute | None = None,
*,
attr_name: str | None = None,
Expand All @@ -775,7 +753,7 @@ def opt_attr_def(


def operand_def(
constraint: (AttrConstraint | Attribute | type[Attribute] | TypeVar) = Attribute,
constraint: IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -788,14 +766,7 @@ def operand_def(


def var_operand_def(
constraint: (
RangeConstraint
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
| typing_extensions.TypeVar
) = Attribute,
constraint: RangeConstraint | IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -808,9 +779,7 @@ def var_operand_def(


def opt_operand_def(
constraint: (
RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar
) = Attribute,
constraint: RangeConstraint | IRDLAttributeConstraint = Attribute,
*,
default: None = None,
resolver: None = None,
Expand All @@ -825,13 +794,9 @@ def opt_operand_def(
def region_def(
single_block: Literal["single_block"] | None = None,
*,
entry_args: (
GenericRangeConstraint[Attribute]
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
) = RangeOf(AnyAttr()),
entry_args: GenericRangeConstraint[Attribute] | IRDLAttributeConstraint = RangeOf(
AnyAttr()
),
default: None = None,
resolver: None = None,
init: Literal[False] = False,
Expand All @@ -846,13 +811,9 @@ def region_def(
def var_region_def(
single_block: Literal["single_block"] | None = None,
*,
entry_args: (
GenericRangeConstraint[Attribute]
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
) = RangeOf(AnyAttr()),
entry_args: GenericRangeConstraint[Attribute] | IRDLAttributeConstraint = RangeOf(
AnyAttr()
),
default: None = None,
resolver: None = None,
init: Literal[False] = False,
Expand All @@ -867,13 +828,9 @@ def var_region_def(
def opt_region_def(
single_block: Literal["single_block"] | None = None,
*,
entry_args: (
GenericRangeConstraint[Attribute]
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
) = RangeOf(AnyAttr()),
entry_args: GenericRangeConstraint[Attribute] | IRDLAttributeConstraint = RangeOf(
AnyAttr()
),
default: None = None,
resolver: None = None,
init: Literal[False] = False,
Expand Down Expand Up @@ -1118,13 +1075,7 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError:

# Get attribute constraints from a list of pyrdl constraints
def get_constraint(
pyrdl_constr: (
AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
| ConstraintVar
),
pyrdl_constr: IRDLAttributeConstraint,
) -> AttrConstraint:
return irdl_list_to_attr_constraint(
(pyrdl_constr,),
Expand All @@ -1134,17 +1085,11 @@ def get_constraint(

# Get attribute constraints from a list of pyrdl constraints
def get_range_constraint(
pyrdl_constr: (
RangeConstraint
| AttrConstraint
| Attribute
| type[Attribute]
| TypeVar
| ConstraintVar
),
pyrdl_constr: RangeConstraint | IRDLAttributeConstraint,
) -> RangeConstraint:
if isinstance(pyrdl_constr, GenericRangeConstraint):
return pyrdl_constr
# Pyright does not know the type of the generic range constraint
return cast(RangeConstraint, pyrdl_constr)
return RangeOf(get_constraint(pyrdl_constr))

field_names.add(field_name)
Expand Down
4 changes: 3 additions & 1 deletion xdsl/utils/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def isa(arg: Any, hint: "TypeForm[_T]") -> TypeGuard[_T]:
from xdsl.irdl import GenericData, irdl_to_attr_constraint

if (origin is not None) and issubclass(origin, GenericData | ParametrizedAttribute):
constraint = irdl_to_attr_constraint(hint)
constraint = irdl_to_attr_constraint(
hint # pyright: ignore[reportArgumentType]
)
try:
constraint.verify(arg, ConstraintContext())
return True
Expand Down

0 comments on commit c6cc4a9

Please sign in to comment.