diff --git a/src/basedtyping/__init__.py b/src/basedtyping/__init__.py index 4c58362..069819e 100644 --- a/src/basedtyping/__init__.py +++ b/src/basedtyping/__init__.py @@ -624,8 +624,10 @@ def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class= try: code = compile(arg_to_compile, "", "eval") except SyntaxError: + if arg.startswith("def"): + arg = arg[3:].lstrip() code = compile( - ast.parse(arg.removeprefix("def").lstrip(), mode="func_type"), + ast.parse(arg, mode="func_type"), "", "func_type", ast.PyCF_ONLY_AST, @@ -641,8 +643,10 @@ def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class= def _evaluate( self, - globalns: dict[str, Any] | None, - localns: dict[str, Any] | None, + globalns: typing.Mapping[str, object] | None, + localns: typing.Mapping[str, object] | None, + type_params: object = None, + *, recursive_guard: frozenset[str] | None = None, ) -> Any: return transformer._eval_direct(self, globalns, localns) diff --git a/src/basedtyping/transformer.py b/src/basedtyping/transformer.py index 1abc859..62dea21 100644 --- a/src/basedtyping/transformer.py +++ b/src/basedtyping/transformer.py @@ -8,12 +8,10 @@ import uuid from contextlib import contextmanager from enum import Enum -from typing import Any import typing_extensions import basedtyping -from basedtyping import ForwardRef class CringeTransformer(ast.NodeTransformer): """ @@ -61,7 +59,7 @@ def eval_type( ) -> object: if not isinstance(node, ast.Expression): node = ast.copy_location(ast.Expression(node), node) - ref = ForwardRef(ast.dump(node)) + ref = typing.ForwardRef(ast.dump(node)) if original_ref: for attr in ("is_argument", " is_class", "module"): attr = f"__forward_{attr}__" @@ -79,22 +77,22 @@ def eval_type( except TypeError as err: return None - def _typing(self, attr: str): + def _typing(self, attr: str) -> ast.Attribute: result = ast.Attribute( value=ast.Name(id=self.typing_name, ctx=ast.Load()), attr=attr, ctx=ast.Load() ) return ast.fix_missing_locations(result) - def _basedtyping(self, attr: str): + def _basedtyping(self, attr: str) -> ast.Attribute: result = ast.Attribute( value=ast.Name(id=self.basedtyping_name, ctx=ast.Load()), attr=attr, ctx=ast.Load() ) return ast.fix_missing_locations(result) - def _literal(self, value: ast.Constant | ast.Name | ast.Attribute): + def _literal(self, value: ast.Constant | ast.Name | ast.Attribute) -> ast.Subscript: return self.subscript(self._typing("Literal"), value) - def subscript(self, value, slice): + def subscript(self, value: ast.AST, slice: ast.AST) -> ast.Subscript: result = ast.Subscript(value=value, slice=ast.Index(slice), ctx=ast.Load()) return ast.fix_missing_locations(result) @@ -122,26 +120,27 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: node = self.subscript(self._typing("Callable"), node.slice) return node - def visit_Attribute(self, node) -> ast.Name: + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: node = self.generic_visit(node) + assert isinstance(node, ast.expr) node_type = self.eval_type(node) if isinstance(node_type, Enum): - node = self._literal(node) + assert isinstance(node, (ast.Name, ast.Attribute)) + return self._literal(node) return node - def visit_Name(self, node) -> ast.Name: - node = self.generic_visit(node) + def visit_Name(self, node: ast.Name) -> ast.AST: name_type = self.eval_type(node) if isinstance(name_type, Enum): - node = self._literal(node) + return self._literal(node) return node def visit_Constant(self, node: ast.Constant) -> ast.AST: - node = self.generic_visit(node) - if isinstance(node.value, int) or ( - self.string_literals and isinstance(node.value, str) + value = typing.cast(object, node.value) + if isinstance(value, int) or ( + self.string_literals and isinstance(value, str) ): - node = self._literal(node) + return self._literal(node) return node def visit_Tuple(self, node: ast.Tuple) -> ast.AST: @@ -152,7 +151,8 @@ def visit_Tuple(self, node: ast.Tuple) -> ast.AST: def visit_Compare(self, node: ast.Compare) -> ast.AST: if len(node.ops) == 1 and isinstance(node.ops[0], ast.Is): - node = self.subscript(self._typing("TypeIs"), self.generic_visit(node.comparators[0])) + result = self.subscript(self._typing("TypeIs"), self.generic_visit(node.comparators[0])) + return self.generic_visit(result) return self.generic_visit(node) def visit_IfExp(self, node: ast.IfExp) -> ast.AST: @@ -168,6 +168,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.AST: def visit_FunctionType(self, node: ast.FunctionType) -> ast.AST: node = self.generic_visit(node) + assert isinstance(node, ast.FunctionType) return self.subscript( self._typing("Callable"), ast.Tuple([ast.List(node.argtypes, ctx=ast.Load()), node.returns], ctx=ast.Load()),