Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Aug 20, 2024
1 parent e98c2b5 commit 0812db6
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .idea/basedtyping.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion basedtyping/__init__.py → src/basedtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import typing_extensions
from typing_extensions import Never, ParamSpec, Self, TypeAlias, TypeGuard, TypeVarTuple

from basedtyping import transformer
from basedtyping.runtime_only import OldUnionType

if not TYPE_CHECKING:
Expand Down Expand Up @@ -606,6 +607,9 @@ class ForwardRef(typing.ForwardRef, _root=True): # type: ignore[call-arg,misc]
if the original syntax is not supported in the current Python version.
"""

# older typing.ForwardRef doesn't have this
__slots__ = ["__forward_module__", "__forward_is_class__"]

def __init__(self, arg, *, is_argument=True, module=None, is_class=False):
if not isinstance(arg, str):
raise TypeError(f"Forward reference must be a string -- got {arg!r}")
Expand Down Expand Up @@ -641,4 +645,4 @@ def _evaluate(
localns: dict[str, Any] | None,
recursive_guard: frozenset[str] | None = None,
) -> Any:
return typing.t_eval_direct(self, globalns, localns)
return transformer._eval_direct(self, globalns, localns)
File renamed without changes.
File renamed without changes.
59 changes: 45 additions & 14 deletions basedtyping/transformer.py → src/basedtyping/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from contextlib import contextmanager
from enum import Enum
from typing import Any
import basedtyping

import typing_extensions

import basedtyping
from basedtyping import ForwardRef

class CringeTransformer(ast.NodeTransformer):
"""
Expand Down Expand Up @@ -37,7 +40,8 @@ def __init__(
assert globalns is not None
localns = globalns

self.typing_name = f"typing_extensions_{uuid.uuid4().hex}"
self.typing_name = f"typing_extensions"
# self.typing_name = f"typing_extensions_{uuid.uuid4().hex}"
self.basedtyping_name = f"basedtyping_{uuid.uuid4().hex}"
self.globalns = globalns
import typing_extensions
Expand All @@ -53,16 +57,22 @@ def eval_type(
) -> object:
if not isinstance(node, ast.Expression):
node = ast.copy_location(ast.Expression(node), node)
ref = typing.ForwardRef(ast.dump(node))
ref = ForwardRef(ast.dump(node))
if original_ref:
for attr in ("is_argument", " is_class", "module"):
attr = f"__forward_{attr}__"
if hasattr(original_ref, attr):
setattr(ref, attr, getattr(original_ref, attr))
ref.__forward_code__ = compile(node, "<node>", "eval")
try:
setattr(ref, attr, getattr(original_ref, attr))
except AttributeError:
pass # older ForwardRefs don't have
try:
ref.__forward_code__ = compile(node, "<node>", "eval")
except Exception as err:
...
try:
return typing._eval_type(ref, self.globalns, self.localns)
except TypeError:
except TypeError as err:
return None

def _typing(self, attr: str):
Expand All @@ -81,7 +91,7 @@ def _literal(self, value: ast.Constant | ast.Name | ast.Attribute):
return self.subscript(self._typing("Literal"), value)

def subscript(self, value, slice):
result = ast.Subscript(value=value, slice=slice, ctx=ast.Load())
result = ast.Subscript(value=value, slice=ast.Index(slice), ctx=ast.Load())
return ast.fix_missing_locations(result)

_implicit_tuple = False
Expand All @@ -96,6 +106,10 @@ def implicit_tuple(self):
self._implicit_tuple = implicit_tuple

def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
node_type = self.eval_type(node.value)
if node_type is typing_extensions.Annotated:
node.slice.value.elts[0] = self.visit(node.slice.value.elts[0])
return node
with self.implicit_tuple():
node = self.generic_visit(node)
# TODO: FunctionType -> Callable
Expand All @@ -104,12 +118,12 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
node = self.subscript(self._typing("Callable"), node.slice)
return node

def visit_Attribute(self, node) -> ast.Name:
node = self.generic_visit(node)
node_type = self.eval_type(node)
if isinstance(node_type, Enum):
node = self._literal(node)
return node
# def visit_Attribute(self, node) -> ast.Name:
# node = self.generic_visit(node)
# node_type = self.eval_type(node)
# if isinstance(node_type, Enum):
# node = self._literal(node)
# return node

def visit_Name(self, node) -> ast.Name:
node = self.generic_visit(node)
Expand All @@ -120,7 +134,7 @@ def visit_Name(self, node) -> ast.Name:

def visit_Constant(self, node: ast.Constant) -> ast.AST:
node = self.generic_visit(node)
if isinstance(node.value, int | bool) or (
if isinstance(node.value, int) or (
self.string_literals and isinstance(node.value, str)
):
node = self._literal(node)
Expand Down Expand Up @@ -173,6 +187,23 @@ def _eval_direct(
return eval_type_based(value, globalns, localns, string_literals=False)


def crifigy_type(
value: str,
globalns: typing.Mapping[str, object] | None = None,
localns: typing.Mapping[str, object] | None = None,
*,
string_literals: bool,
) -> object:
try:
tree = ast.parse(value, mode="eval")
except SyntaxError:
tree = ast.parse(value.removeprefix("def").lstrip(), mode="func_type")

transformer = CringeTransformer(globalns, localns, string_literals=string_literals)
tree = transformer.visit(tree)
return ast.unparse(tree)


def eval_type_based(
value: object,
globalns: typing.Mapping[str, object] | None = None,
Expand Down
File renamed without changes.
9 changes: 5 additions & 4 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from enum import Enum
from typing import Tuple

from transformer import eval_type_based
from typing_extensions import Callable, Literal, TypeIs, Union, TypeGuard
from types import FunctionType
from basedtyping.transformer import eval_type_based
from typing_extensions import Callable, Literal, TypeIs, Union, TypeGuard, Annotated
from basedtyping import ForwardRef, Intersection

# ruff: noqa: PYI030, PYI030
Expand All @@ -30,7 +29,6 @@ def test_literal_str():
validate("'int'", Literal["int"], string_literals=True)
validate("Literal['int']", Literal["int"], string_literals=True)


class E(Enum):
a = 1
b = 2
Expand Down Expand Up @@ -71,3 +69,6 @@ def test_intersection():

def test_nested():
validate("(1, 2)", Tuple[Literal[1], Literal[2]])

def test_annotated():
validate("Annotated[1, 1]", Annotated[Literal[1], 1])
2 changes: 1 addition & 1 deletion tests/test_typetime_only/test_typetime_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

def test_runtime_import() -> None:
with raises(ImportError):
import basedtyping.typetime_only # noqa: F401
pass

0 comments on commit 0812db6

Please sign in to comment.