diff --git a/basedtyping/__init__.py b/basedtyping/__init__.py index a4d5bd1..a25bb31 100644 --- a/basedtyping/__init__.py +++ b/basedtyping/__init__.py @@ -72,6 +72,7 @@ "as_functiontype", "ForwardRef", "BASEDMYPY_TYPE_CHECKING", + "get_type_hints", ) if TYPE_CHECKING: @@ -743,3 +744,110 @@ def _type_check(arg: object, msg: str) -> object: if not callable(arg): raise TypeError(f"{msg} Got {arg!r:.100}.") return arg + + +_strip_annotations = typing._strip_annotations # type: ignore[attr-defined] + + +def get_type_hints( # type: ignore[no-any-explicit] + obj: object + | Callable[..., object] + | FunctionType[..., object] + | types.BuiltinFunctionType[..., object] + | types.MethodType + | types.ModuleType + | types.WrapperDescriptorType + | types.MethodWrapperType + | types.MethodDescriptorType, + globalns: dict[str, object] | None = None, + localns: dict[str, object] | None = None, + include_extras: bool = False, # noqa: FBT001, FBT002 +) -> dict[str, object]: + """Return type hints for an object. + + same as `typing.get_type_hints` except: + - supports based typing denotations + - adds the class to the scope: + + ```py + class Base: + def __init_subclass__(cls): + get_type_hints(cls) + + class A(Base): + a: A + ``` + """ + if getattr(obj, "__no_type_check__", None): # type: ignore[no-any-expr] + return {} + # Classes require a special treatment. + if isinstance(obj, type): # type: ignore[no-any-expr] + hints = {} + for base in reversed(obj.__mro__): + if globalns is None: + base_globals = getattr(sys.modules.get(base.__module__, None), "__dict__", {}) # type: ignore[no-any-expr] + else: + base_globals = globalns + ann = base.__dict__.get("__annotations__", {}) # type: ignore[no-any-expr] + if isinstance(ann, types.GetSetDescriptorType): # type: ignore[no-any-expr] + ann = {} # type: ignore[no-any-expr] + base_locals = dict(vars(base)) if localns is None else localns # type: ignore[no-any-expr] + if localns is None and globalns is None: + # This is surprising, but required. Before Python 3.10, + # get_type_hints only evaluated the globalns of + # a class. To maintain backwards compatibility, we reverse + # the globalns and localns order so that eval() looks into + # *base_globals* first rather than *base_locals*. + # This only affects ForwardRefs. + base_globals, base_locals = base_locals, base_globals + # start not copied section + if base is obj: + # add the class to the scope + base_locals[obj.__name__] = obj # type: ignore[no-any-expr] + # end not copied section + for name, value in ann.items(): # type: ignore[no-any-expr] + if value is None: # type: ignore[no-any-expr] + value = type(None) + if isinstance(value, str): # type: ignore[no-any-expr] + if sys.version_info < (3, 9): + value = ForwardRef(value, is_argument=False) + else: + value = ForwardRef(value, is_argument=False, is_class=True) + value = typing._eval_type(value, base_globals, base_locals, recursive_guard=1) # type: ignore[attr-defined, no-any-expr] + hints[name] = value # type: ignore[no-any-expr] + + return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} # type: ignore[no-any-expr] + + if globalns is None: + if isinstance(obj, types.ModuleType): # type: ignore[no-any-expr] + globalns = obj.__dict__ + else: + nsobj = obj + # Find globalns for the unwrapped object. + while hasattr(nsobj, "__wrapped__"): + nsobj = nsobj.__wrapped__ # type: ignore[no-any-expr] + globalns = getattr(nsobj, "__globals__", {}) # type: ignore[no-any-expr] + if localns is None: + localns = globalns + elif localns is None: + localns = globalns + hints = getattr(obj, "__annotations__", None) # type: ignore[assignment, no-any-expr] + if hints is None: # type: ignore[no-any-expr, redundant-expr] + # Return empty annotations for something that _could_ have them. + if isinstance(obj, typing._allowed_types): # type: ignore[ unreachable] + return {} + raise TypeError(f"{obj!r} is not a module, class, method, " "or function.") + hints = dict(hints) # type: ignore[no-any-expr] + for name, value in hints.items(): # type: ignore[no-any-expr] + if value is None: # type: ignore[no-any-expr] + value = type(None) + if isinstance(value, str): # type: ignore[no-any-expr] + # class-level forward refs were handled above, this must be either + # a module-level annotation or a function argument annotation + value = ForwardRef( + value, + is_argument=not isinstance(cast(object, obj), types.ModuleType), + is_class=False, + ) + hints[name] = typing._eval_type(value, globalns, localns) # type: ignore[no-any-expr, attr-defined] + return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} # type: ignore[no-any-expr] diff --git a/pyproject.toml b/pyproject.toml index ed5df6e..25c8bed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [ ] description = "Utilities for basedmypy" name = "basedtyping" -version = "0.1.8" +version = "0.1.9" [tool.poetry.dependencies] python = "^3.9" diff --git a/tests/test_get_type_hints.py b/tests/test_get_type_hints.py new file mode 100644 index 0000000..3552c36 --- /dev/null +++ b/tests/test_get_type_hints.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import re + +from typing_extensions import Literal, Union, override + +from basedtyping import get_type_hints + + +def test_get_type_hints_class(): + result: object = None + + class Base: + @override + def __init_subclass__(cls): + nonlocal result + result = get_type_hints(cls) + + class A(Base): + a: A + + assert result == {"a": A} + + +def test_get_type_hints_based(): + class A: + a: Union[re.RegexFlag.ASCII, re.RegexFlag.DOTALL] + + assert get_type_hints(A) == { + "a": Union[Literal[re.RegexFlag.ASCII], Literal[re.RegexFlag.DOTALL]] # noqa: PYI030 + }