From 0812db668dab19e7b25fa4a789e8e6f95137d6d1 Mon Sep 17 00:00:00 2001 From: KotlinIsland Date: Sun, 18 Aug 2024 14:40:26 +1000 Subject: [PATCH] wip --- .idea/basedtyping.iml | 2 +- {basedtyping => src/basedtyping}/__init__.py | 6 +- {basedtyping => src/basedtyping}/py.typed | 0 .../basedtyping}/runtime_only.py | 0 .../basedtyping}/transformer.py | 59 ++++++++++++++----- .../basedtyping}/typetime_only.py | 0 tests/test_transformer.py | 9 +-- .../test_typetime_only/test_typetime_only.py | 2 +- 8 files changed, 57 insertions(+), 21 deletions(-) rename {basedtyping => src/basedtyping}/__init__.py (99%) rename {basedtyping => src/basedtyping}/py.typed (100%) rename {basedtyping => src/basedtyping}/runtime_only.py (100%) rename {basedtyping => src/basedtyping}/transformer.py (79%) rename {basedtyping => src/basedtyping}/typetime_only.py (100%) diff --git a/.idea/basedtyping.iml b/.idea/basedtyping.iml index 7ad5deb..e83e1dc 100644 --- a/.idea/basedtyping.iml +++ b/.idea/basedtyping.iml @@ -3,8 +3,8 @@ - + diff --git a/basedtyping/__init__.py b/src/basedtyping/__init__.py similarity index 99% rename from basedtyping/__init__.py rename to src/basedtyping/__init__.py index b871c15..20f38c8 100644 --- a/basedtyping/__init__.py +++ b/src/basedtyping/__init__.py @@ -32,6 +32,7 @@ import typing_extensions from typing_extensions import Never, ParamSpec, Self, TypeAlias, TypeGuard, TypeVarTuple +from basedtyping import transformer from basedtyping.runtime_only import OldUnionType if not TYPE_CHECKING: @@ -606,6 +607,9 @@ class ForwardRef(typing.ForwardRef, _root=True): # type: ignore[call-arg,misc] if the original syntax is not supported in the current Python version. """ + # older typing.ForwardRef doesn't have this + __slots__ = ["__forward_module__", "__forward_is_class__"] + def __init__(self, arg, *, is_argument=True, module=None, is_class=False): if not isinstance(arg, str): raise TypeError(f"Forward reference must be a string -- got {arg!r}") @@ -641,4 +645,4 @@ def _evaluate( localns: dict[str, Any] | None, recursive_guard: frozenset[str] | None = None, ) -> Any: - return typing.t_eval_direct(self, globalns, localns) + return transformer._eval_direct(self, globalns, localns) diff --git a/basedtyping/py.typed b/src/basedtyping/py.typed similarity index 100% rename from basedtyping/py.typed rename to src/basedtyping/py.typed diff --git a/basedtyping/runtime_only.py b/src/basedtyping/runtime_only.py similarity index 100% rename from basedtyping/runtime_only.py rename to src/basedtyping/runtime_only.py diff --git a/basedtyping/transformer.py b/src/basedtyping/transformer.py similarity index 79% rename from basedtyping/transformer.py rename to src/basedtyping/transformer.py index da867a9..b6b0144 100644 --- a/basedtyping/transformer.py +++ b/src/basedtyping/transformer.py @@ -7,8 +7,11 @@ from contextlib import contextmanager from enum import Enum from typing import Any -import basedtyping +import typing_extensions + +import basedtyping +from basedtyping import ForwardRef class CringeTransformer(ast.NodeTransformer): """ @@ -37,7 +40,8 @@ def __init__( assert globalns is not None localns = globalns - self.typing_name = f"typing_extensions_{uuid.uuid4().hex}" + self.typing_name = f"typing_extensions" + # self.typing_name = f"typing_extensions_{uuid.uuid4().hex}" self.basedtyping_name = f"basedtyping_{uuid.uuid4().hex}" self.globalns = globalns import typing_extensions @@ -53,16 +57,22 @@ def eval_type( ) -> object: if not isinstance(node, ast.Expression): node = ast.copy_location(ast.Expression(node), node) - ref = typing.ForwardRef(ast.dump(node)) + ref = ForwardRef(ast.dump(node)) if original_ref: for attr in ("is_argument", " is_class", "module"): attr = f"__forward_{attr}__" if hasattr(original_ref, attr): - setattr(ref, attr, getattr(original_ref, attr)) - ref.__forward_code__ = compile(node, "", "eval") + try: + setattr(ref, attr, getattr(original_ref, attr)) + except AttributeError: + pass # older ForwardRefs don't have + try: + ref.__forward_code__ = compile(node, "", "eval") + except Exception as err: + ... try: return typing._eval_type(ref, self.globalns, self.localns) - except TypeError: + except TypeError as err: return None def _typing(self, attr: str): @@ -81,7 +91,7 @@ def _literal(self, value: ast.Constant | ast.Name | ast.Attribute): return self.subscript(self._typing("Literal"), value) def subscript(self, value, slice): - result = ast.Subscript(value=value, slice=slice, ctx=ast.Load()) + result = ast.Subscript(value=value, slice=ast.Index(slice), ctx=ast.Load()) return ast.fix_missing_locations(result) _implicit_tuple = False @@ -96,6 +106,10 @@ def implicit_tuple(self): self._implicit_tuple = implicit_tuple def visit_Subscript(self, node: ast.Subscript) -> ast.AST: + node_type = self.eval_type(node.value) + if node_type is typing_extensions.Annotated: + node.slice.value.elts[0] = self.visit(node.slice.value.elts[0]) + return node with self.implicit_tuple(): node = self.generic_visit(node) # TODO: FunctionType -> Callable @@ -104,12 +118,12 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: node = self.subscript(self._typing("Callable"), node.slice) return node - def visit_Attribute(self, node) -> ast.Name: - node = self.generic_visit(node) - node_type = self.eval_type(node) - if isinstance(node_type, Enum): - node = self._literal(node) - return node + # def visit_Attribute(self, node) -> ast.Name: + # node = self.generic_visit(node) + # node_type = self.eval_type(node) + # if isinstance(node_type, Enum): + # node = self._literal(node) + # return node def visit_Name(self, node) -> ast.Name: node = self.generic_visit(node) @@ -120,7 +134,7 @@ def visit_Name(self, node) -> ast.Name: def visit_Constant(self, node: ast.Constant) -> ast.AST: node = self.generic_visit(node) - if isinstance(node.value, int | bool) or ( + if isinstance(node.value, int) or ( self.string_literals and isinstance(node.value, str) ): node = self._literal(node) @@ -173,6 +187,23 @@ def _eval_direct( return eval_type_based(value, globalns, localns, string_literals=False) +def crifigy_type( + value: str, + globalns: typing.Mapping[str, object] | None = None, + localns: typing.Mapping[str, object] | None = None, + *, + string_literals: bool, +) -> object: + try: + tree = ast.parse(value, mode="eval") + except SyntaxError: + tree = ast.parse(value.removeprefix("def").lstrip(), mode="func_type") + + transformer = CringeTransformer(globalns, localns, string_literals=string_literals) + tree = transformer.visit(tree) + return ast.unparse(tree) + + def eval_type_based( value: object, globalns: typing.Mapping[str, object] | None = None, diff --git a/basedtyping/typetime_only.py b/src/basedtyping/typetime_only.py similarity index 100% rename from basedtyping/typetime_only.py rename to src/basedtyping/typetime_only.py diff --git a/tests/test_transformer.py b/tests/test_transformer.py index a3c6117..6c9cafe 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,9 +1,8 @@ from enum import Enum from typing import Tuple -from transformer import eval_type_based -from typing_extensions import Callable, Literal, TypeIs, Union, TypeGuard -from types import FunctionType +from basedtyping.transformer import eval_type_based +from typing_extensions import Callable, Literal, TypeIs, Union, TypeGuard, Annotated from basedtyping import ForwardRef, Intersection # ruff: noqa: PYI030, PYI030 @@ -30,7 +29,6 @@ def test_literal_str(): validate("'int'", Literal["int"], string_literals=True) validate("Literal['int']", Literal["int"], string_literals=True) - class E(Enum): a = 1 b = 2 @@ -71,3 +69,6 @@ def test_intersection(): def test_nested(): validate("(1, 2)", Tuple[Literal[1], Literal[2]]) + +def test_annotated(): + validate("Annotated[1, 1]", Annotated[Literal[1], 1]) diff --git a/tests/test_typetime_only/test_typetime_only.py b/tests/test_typetime_only/test_typetime_only.py index 49c6376..5971f19 100644 --- a/tests/test_typetime_only/test_typetime_only.py +++ b/tests/test_typetime_only/test_typetime_only.py @@ -5,4 +5,4 @@ def test_runtime_import() -> None: with raises(ImportError): - import basedtyping.typetime_only # noqa: F401 + pass