diff --git a/src/scanpydoc/elegant_typehints/_formatting.py b/src/scanpydoc/elegant_typehints/_formatting.py index 4dde8d3..dd1ef66 100644 --- a/src/scanpydoc/elegant_typehints/_formatting.py +++ b/src/scanpydoc/elegant_typehints/_formatting.py @@ -3,7 +3,10 @@ import inspect from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin +from types import UnionType +from typing import TYPE_CHECKING, Any, Literal, Union, get_args, get_origin +from typing import Callable as t_Callable +from typing import Mapping as t_Mapping # noqa: UP035 from docutils import nodes from docutils.parsers.rst.roles import set_classes @@ -45,22 +48,19 @@ def _format_full(annotation: type[Any], config: Config) -> str | None: def _format_terse(annotation: type[Any], config: Config) -> str: - from collections.abc import Mapping as t_Mapping - from typing import Union - origin = get_origin(annotation) args = get_args(annotation) tilde = "" if config.typehints_fully_qualified else "~" fmt = partial(_format_terse, config=config) # display `Union[A, B]` as `A | B` - if origin is Union: + if origin in (Union, UnionType): # Never use the `Optional` keyword in the displayed docs. # Use `| None` instead, similar to other large numerical packages. return " | ".join(map(fmt, args)) # do not show the arguments of Mapping - if origin is Mapping or origin is t_Mapping: + if origin in (Mapping, t_Mapping): return f":py:class:`{tilde}collections.abc.Mapping`" # display dict as {k: v} @@ -69,7 +69,7 @@ def _format_terse(annotation: type[Any], config: Config) -> str: return f"{{{fmt(k)}: {fmt(v)}}}" # display Callable[[a1, a2], r] as (a1, a2) -> r - if origin is Callable and len(args) == 2: # noqa: PLR2004 + if origin in (Callable, t_Callable) and len(args) == 2: # noqa: PLR2004 params, ret = args params = ["…"] if params is Ellipsis else map(fmt, params) return f"({', '.join(params)}) → {fmt(ret)}" diff --git a/tests/test_elegant_typehints.py b/tests/test_elegant_typehints.py index e0946e6..c1a9e72 100644 --- a/tests/test_elegant_typehints.py +++ b/tests/test_elegant_typehints.py @@ -326,7 +326,7 @@ class B: ( tuple[Mapping[str, float], int], r":annotation-terse:`:py:class:\`~collections.abc.Mapping\``\ " - r":annotation-full:`:py:class:\`~typing.Mapping\`\[" + r":annotation-full:`:py:class:\`~collections.abc.Mapping\`\[" r":py:class:\`str\`, :py:class:\`float\`" r"]`", ),