Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Aug 24, 2024
1 parent 3a44a06 commit 9703c44
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
10 changes: 7 additions & 3 deletions src/basedtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,10 @@ def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class=
try:
code = compile(arg_to_compile, "<string>", "eval")
except SyntaxError:
if arg.startswith("def"):
arg = arg[3:].lstrip()
code = compile(
ast.parse(arg.removeprefix("def").lstrip(), mode="func_type"),
ast.parse(arg, mode="func_type"),
"<string>",
"func_type",
ast.PyCF_ONLY_AST,
Expand All @@ -641,8 +643,10 @@ def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class=

def _evaluate(
self,
globalns: dict[str, Any] | None,
localns: dict[str, Any] | None,
globalns: typing.Mapping[str, object] | None,
localns: typing.Mapping[str, object] | None,
type_params: object = None,
*,
recursive_guard: frozenset[str] | None = None,
) -> Any:
return transformer._eval_direct(self, globalns, localns)
35 changes: 18 additions & 17 deletions src/basedtyping/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
import uuid
from contextlib import contextmanager
from enum import Enum
from typing import Any

import typing_extensions

import basedtyping
from basedtyping import ForwardRef

class CringeTransformer(ast.NodeTransformer):
"""
Expand Down Expand Up @@ -61,7 +59,7 @@ def eval_type(
) -> object:
if not isinstance(node, ast.Expression):
node = ast.copy_location(ast.Expression(node), node)
ref = ForwardRef(ast.dump(node))
ref = typing.ForwardRef(ast.dump(node))
if original_ref:
for attr in ("is_argument", " is_class", "module"):
attr = f"__forward_{attr}__"
Expand All @@ -79,22 +77,22 @@ def eval_type(
except TypeError as err:
return None

def _typing(self, attr: str):
def _typing(self, attr: str) -> ast.Attribute:
result = ast.Attribute(
value=ast.Name(id=self.typing_name, ctx=ast.Load()), attr=attr, ctx=ast.Load()
)
return ast.fix_missing_locations(result)

def _basedtyping(self, attr: str):
def _basedtyping(self, attr: str) -> ast.Attribute:
result = ast.Attribute(
value=ast.Name(id=self.basedtyping_name, ctx=ast.Load()), attr=attr, ctx=ast.Load()
)
return ast.fix_missing_locations(result)

def _literal(self, value: ast.Constant | ast.Name | ast.Attribute):
def _literal(self, value: ast.Constant | ast.Name | ast.Attribute) -> ast.Subscript:
return self.subscript(self._typing("Literal"), value)

def subscript(self, value, slice):
def subscript(self, value: ast.AST, slice: ast.AST) -> ast.Subscript:
result = ast.Subscript(value=value, slice=ast.Index(slice), ctx=ast.Load())
return ast.fix_missing_locations(result)

Expand Down Expand Up @@ -122,26 +120,27 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
node = self.subscript(self._typing("Callable"), node.slice)
return node

def visit_Attribute(self, node) -> ast.Name:
def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
node = self.generic_visit(node)
assert isinstance(node, ast.expr)
node_type = self.eval_type(node)
if isinstance(node_type, Enum):
node = self._literal(node)
assert isinstance(node, (ast.Name, ast.Attribute))
return self._literal(node)
return node

def visit_Name(self, node) -> ast.Name:
node = self.generic_visit(node)
def visit_Name(self, node: ast.Name) -> ast.AST:
name_type = self.eval_type(node)
if isinstance(name_type, Enum):
node = self._literal(node)
return self._literal(node)
return node

def visit_Constant(self, node: ast.Constant) -> ast.AST:
node = self.generic_visit(node)
if isinstance(node.value, int) or (
self.string_literals and isinstance(node.value, str)
value = typing.cast(object, node.value)
if isinstance(value, int) or (
self.string_literals and isinstance(value, str)
):
node = self._literal(node)
return self._literal(node)
return node

def visit_Tuple(self, node: ast.Tuple) -> ast.AST:
Expand All @@ -152,7 +151,8 @@ def visit_Tuple(self, node: ast.Tuple) -> ast.AST:

def visit_Compare(self, node: ast.Compare) -> ast.AST:
if len(node.ops) == 1 and isinstance(node.ops[0], ast.Is):
node = self.subscript(self._typing("TypeIs"), self.generic_visit(node.comparators[0]))
result = self.subscript(self._typing("TypeIs"), self.generic_visit(node.comparators[0]))
return self.generic_visit(result)
return self.generic_visit(node)

def visit_IfExp(self, node: ast.IfExp) -> ast.AST:
Expand All @@ -168,6 +168,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.AST:

def visit_FunctionType(self, node: ast.FunctionType) -> ast.AST:
node = self.generic_visit(node)
assert isinstance(node, ast.FunctionType)
return self.subscript(
self._typing("Callable"),
ast.Tuple([ast.List(node.argtypes, ctx=ast.Load()), node.returns], ctx=ast.Load()),
Expand Down

0 comments on commit 9703c44

Please sign in to comment.