Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Aug 28, 2024
1 parent 9703c44 commit f792727
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 47 deletions.
23 changes: 9 additions & 14 deletions src/basedtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def f[T](t: TypeForm[T]) -> T: ...

# TODO: conditionally declare FunctionType with a BASEDMYPY so that this doesn't break everyone else
# https://github.com/KotlinIsland/basedmypy/issues/524
def as_functiontype(fn: Callable[P, T]) -> FunctionType[P, T]: # type: ignore[type-arg]
def as_functiontype(fn: Callable[P, T]) -> FunctionType[P, T]:
"""Asserts that a ``Callable`` is a ``FunctionType`` and returns it
best used as a decorator to fix other incorrectly typed decorators:
Expand All @@ -594,9 +594,10 @@ def deco(fn: Callable[[], None]) -> Callable[[], None]: ...
@deco
def foo(): ...
"""
if not isinstance(fn, FunctionType): # type: ignore[redundant-expr]
if not isinstance(fn, FunctionType):
raise TypeError(f"{fn} is not a FunctionType")
return fn # type: ignore[unreachable]
# https://github.com/KotlinIsland/basedmypy/issues/745
return cast("FunctionType[P, T]", fn)


class ForwardRef(typing.ForwardRef, _root=True): # type: ignore[call-arg,misc]
Expand All @@ -610,8 +611,8 @@ class ForwardRef(typing.ForwardRef, _root=True): # type: ignore[call-arg,misc]
# older typing.ForwardRef doesn't have this
__slots__ = ["__forward_module__", "__forward_is_class__"]

def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class=False):
if not isinstance(arg, str):
def __init__(self, arg: str, *, is_argument=True, module: object = None, is_class=False):
if not isinstance(arg, str): # type: ignore[redundant-expr]
raise TypeError(f"Forward reference must be a string -- got {arg!r}")

# If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
Expand All @@ -624,14 +625,8 @@ def __init__(self, arg: str, *, is_argument=True, module: object=None, is_class=
try:
code = compile(arg_to_compile, "<string>", "eval")
except SyntaxError:
if arg.startswith("def"):
arg = arg[3:].lstrip()
code = compile(
ast.parse(arg, mode="func_type"),
"<string>",
"func_type",
ast.PyCF_ONLY_AST,
)
# Callable: () -> int
code = compile("1", "<string>", "eval")

self.__forward_arg__ = arg
self.__forward_code__ = code
Expand All @@ -648,5 +643,5 @@ def _evaluate(
type_params: object = None,
*,
recursive_guard: frozenset[str] | None = None,
) -> Any:
) -> object:
return transformer._eval_direct(self, globalns, localns)
74 changes: 47 additions & 27 deletions src/basedtyping/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import typing_extensions

import basedtyping
from typing import cast


class CringeTransformer(ast.NodeTransformer):
"""
Expand Down Expand Up @@ -52,29 +54,27 @@ def __init__(
}

def visit(self, node: ast.AST) -> ast.AST:
return typing.cast(ast.AST, super().visit(node))
return cast(ast.AST, super().visit(node))

def eval_type(
self, node: ast.FunctionType | ast.Expression | ast.expr, *, original_ref: typing.ForwardRef | None = None
self,
node: ast.FunctionType | ast.Expression | ast.expr,
*,
original_ref: typing.ForwardRef | None = None,
) -> object:
if not isinstance(node, ast.Expression):
if isinstance(node, ast.expr):
node = ast.copy_location(ast.Expression(node), node)
ref = typing.ForwardRef(ast.dump(node))
if original_ref:
for attr in ("is_argument", " is_class", "module"):
attr = f"__forward_{attr}__"
if hasattr(original_ref, attr):
try:
setattr(ref, attr, getattr(original_ref, attr))
except AttributeError:
pass # older ForwardRefs don't have
try:
setattr(ref, attr, cast(object, getattr(original_ref, attr)))
if not isinstance(node, ast.FunctionType):
ref.__forward_code__ = compile(node, "<node>", "eval")
except Exception as err:
...
try:
return typing._eval_type(ref, self.globalns, self.localns)
except TypeError as err:
return typing._eval_type(ref, self.globalns, self.localns) # type: ignore[attr-defined]
except TypeError :
return None

def _typing(self, attr: str) -> ast.Attribute:
Expand All @@ -92,14 +92,14 @@ def _basedtyping(self, attr: str) -> ast.Attribute:
def _literal(self, value: ast.Constant | ast.Name | ast.Attribute) -> ast.Subscript:
return self.subscript(self._typing("Literal"), value)

def subscript(self, value: ast.AST, slice: ast.AST) -> ast.Subscript:
result = ast.Subscript(value=value, slice=ast.Index(slice), ctx=ast.Load())
def subscript(self, value: ast.expr, slice_: ast.expr) -> ast.Subscript:
result = ast.Subscript(value=value, slice=ast.Index(slice_), ctx=ast.Load())
return ast.fix_missing_locations(result)

_implicit_tuple = False

@contextmanager
def implicit_tuple(self):
def implicit_tuple(self) -> typing.Iterator[None]:
implicit_tuple = self._implicit_tuple
self._implicit_tuple = True
try:
Expand All @@ -110,14 +110,31 @@ def implicit_tuple(self):
def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
node_type = self.eval_type(node.value)
if node_type is typing_extensions.Annotated:
node.slice.elts[0] = self.visit(node.slice.elts[0])
if sys.version_info < (3, 9):
slice_ = cast(ast.Index, node.slice)
if isinstance(slice_.value, ast.Tuple):
slice_.value.elts[0] = cast(ast.expr, self.visit(slice_.value.elts[0]))
else:
slice_.value = cast(ast.expr, self.visit(slice_.value))
else:
slice_ = node.slice
if isinstance(slice_, ast.Tuple):
slice_.elts[0] = self.visit(slice_.elts[0])
else:
node.slice = self.visit(slice_)
return node
with self.implicit_tuple():
node = self.generic_visit(node)
# TODO: FunctionType -> Callable
result = self.generic_visit(node)
assert isinstance(result, ast.Subscript)
node = result

node_type = self.eval_type(node.value)
if node_type is types.FunctionType:
node = self.subscript(self._typing("Callable"), node.slice)
if sys.version_info < (3, 9):
slice2_ = cast(ast.Index, node.slice).value
else:
slice2_ = node.slice
node = self.subscript(self._typing("Callable"), slice2_)
return node

def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
Expand All @@ -136,22 +153,22 @@ def visit_Name(self, node: ast.Name) -> ast.AST:
return node

def visit_Constant(self, node: ast.Constant) -> ast.AST:
value = typing.cast(object, node.value)
if isinstance(value, int) or (
self.string_literals and isinstance(value, str)
):
value = cast(object, node.value)
if isinstance(value, int) or (self.string_literals and isinstance(value, str)):
return self._literal(node)
return node

def visit_Tuple(self, node: ast.Tuple) -> ast.AST:
node = self.generic_visit(node)
if not self._implicit_tuple:
return self.subscript(self._typing("Tuple"), node)
return self.subscript(self._typing("Tuple"), cast(ast.expr, node))
return node

def visit_Compare(self, node: ast.Compare) -> ast.AST:
if len(node.ops) == 1 and isinstance(node.ops[0], ast.Is):
result = self.subscript(self._typing("TypeIs"), self.generic_visit(node.comparators[0]))
result = self.subscript(
self._typing("TypeIs"), cast(ast.expr, self.generic_visit(node.comparators[0]))
)
return self.generic_visit(result)
return self.generic_visit(node)

Expand All @@ -162,13 +179,14 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.AST:
and isinstance(node.body.ops[0], ast.Is)
):
node.body = self.subscript(
self._typing("TypeGuard"), self.generic_visit(node.body.comparators[0])
self._typing("TypeGuard"),
cast(ast.expr, self.generic_visit(node.body.comparators[0])),
)
return self.generic_visit(node)

def visit_FunctionType(self, node: ast.FunctionType) -> ast.AST:
node = self.generic_visit(node)
assert isinstance(node, ast.FunctionType)
assert isinstance(node, ast.FunctionType)
return self.subscript(
self._typing("Callable"),
ast.Tuple([ast.List(node.argtypes, ctx=ast.Load()), node.returns], ctx=ast.Load()),
Expand All @@ -191,7 +209,9 @@ def _eval_direct(
) -> object:
return eval_type_based(value, globalns, localns, string_literals=False)


if sys.version_info >= (3, 9):

def crifigy_type(
value: str,
globalns: typing.Mapping[str, object] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_intersection_eq():

def test_intersection_eq_hash():
assert hash(value) == hash(value)
assert hash(value) != other # type: ignore[comparison-overlap]
assert hash(value) != other # type: ignore[comparison-overlap]


def test_intersection_instancecheck():
Expand Down
5 changes: 1 addition & 4 deletions tests/test_is_subform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def test_old_union():
# TODO: fix the mypy error # noqa: TD003
assert not issubform(Union[int, str], int)
assert issubform(Union[int, str], object)
assert issubform(
Union[int, str],
Union[str, int],
)
assert issubform(Union[int, str], Union[str, int])
if sys.version_info >= (3, 10):
assert issubform(
Union[int, str], # type: ignore[arg-type]
Expand Down
8 changes: 7 additions & 1 deletion tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@

def validate(value: str, expected: object, *, string_literals=False):
assert (
eval_type_based(ForwardRef(value), globalns=cast(Dict[str, object], globals()), string_literals=string_literals)
eval_type_based(
ForwardRef(value),
globalns=cast(Dict[str, object], globals()),
string_literals=string_literals,
)
== expected
)

Expand All @@ -30,6 +34,7 @@ def test_literal_str():
validate("'int'", Literal["int"], string_literals=True)
validate("Literal['int']", Literal["int"], string_literals=True)


class E(Enum):
a = 1
b = 2
Expand Down Expand Up @@ -71,5 +76,6 @@ def test_intersection():
def test_nested():
validate("(1, 2)", Tuple[Literal[1], Literal[2]])


def test_annotated():
validate("Annotated[1, 1]", Annotated[Literal[1], 1])

0 comments on commit f792727

Please sign in to comment.