Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve [Async]ContextDecorator type hinting #13416

Merged
merged 8 commits into from
Jan 20, 2025
Merged

Conversation

ncoghlan
Copy link
Contributor

@ncoghlan ncoghlan commented Jan 20, 2025

Updated annotations allow access to __wrapped__ on decorated callables without complaints from typecheckers.

They also correctly indicate that the returned type is NOT the same as the input type.
Instead, it is a different callable type with the same call signature.

Closes #13403

Updated annotations allow access to `__wrapped__` on
decorated callables without complaints from typecheckers.
@ncoghlan
Copy link
Contributor Author

MyPy is not a fan:

stubs/decorator/decorator.pyi:69: error: Signature of "__call__" incompatible with supertype "ContextDecorator"  [override]
stubs/decorator/decorator.pyi:69: note:      Superclass:
stubs/decorator/decorator.pyi:69: note:          def [_P`-1, _R] __call__(self, func: Callable[_P, _R]) -> _WrappedCallable[_P, _R]
stubs/decorator/decorator.pyi:69: note:      Subclass:
stubs/decorator/decorator.pyi:69: note:          def [_C: Callable[..., Any]] __call__(self, func: _C) -> _C

This comment has been minimized.

This comment has been minimized.

This comment has been minimized.

@@ -64,9 +64,13 @@ class AbstractAsyncContextManager(ABC, Protocol[_T_co, _ExitT_co]): # type: ign
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> _ExitT_co: ...

class _WrappedCallable(Callable[_P, _R]):
Copy link
Collaborator

@srittau srittau Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my bad. Callable is a special form, not a class, but the following could work:

Suggested change
class _WrappedCallable(Callable[_P, _R]):
class _WrappedCallable(Generic[_P, _R]):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll switch back to Generic to show the jax failure that version reported in the mypy_primer comment (mypy doesn't seem to consider _WrappedCallable[_P, _R] as conforming to _Callable[_P, _R])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although, thinking about it, the old signature here was genuinely wrong, since it claimed the returned callable type would be the same type as the input (and hence have any extra attributes, etc, that a callable subclass defined), rather than just returning a new callable with a compatible calling signature.

The newly introduced mypy complaints in jax reflect that it had passed along that same erroneous claim.

Copy link
Contributor

Diff from mypy_primer, showing the effect of this PR on open source code:

jax (https://github.com/google/jax)
+ jax/_src/api.py:2685: error: Incompatible return value type (got "_WrappedCallable[[VarArg(Any), KwArg(Any)], Any]", expected "F")  [return-value]
+ jax/experimental/jax2tf/jax2tf.py:3062: error: Argument 1 has incompatible type "_GeneratorContextManager[NameStack, None, None]"; expected "Callable[[partial[Sequence[Any]]], partial[Sequence[Any]]]"  [arg-type]
+ jax/experimental/jax2tf/jax2tf.py:3062: note: "_GeneratorContextManager[NameStack, None, None].__call__" has type "Callable[[Arg(Callable[_P, _R], 'func')], _WrappedCallable[_P, _R]]"

cki-lib (https://gitlab.com/cki-project/cki-lib)
- tests/test_metrics.py:75: error: Call to untyped function "dummy_function" in typed context  [no-untyped-call]

@ncoghlan
Copy link
Contributor Author

ncoghlan commented Jan 20, 2025

These are the new type error reports that this change introduces in jax:

jax (https://github.com/google/jax)
+ jax/_src/api.py:2685: error: Incompatible return value type (got "_WrappedCallable[[VarArg(Any), KwArg(Any)], Any]", expected "F")  [return-value]
+ jax/experimental/jax2tf/jax2tf.py:3062: error: Argument 1 has incompatible type "_GeneratorContextManager[NameStack, None, None]"; expected "Callable[[partial[Sequence[Any]]], partial[Sequence[Any]]]"  [arg-type]
+ jax/experimental/jax2tf/jax2tf.py:3062: note: "_GeneratorContextManager[NameStack, None, None].__call__" has type "Callable[[Arg(Callable[_P, _R], 'func')], _WrappedCallable[_P, _R]]"

The first one is for a generic function with a F -> F signature along the same lines as the existing ContextDecorator.__call__ signature, which overpromises when it comes to the available functionality on the returned type. With this patch, MyPy now complains about that case (while neither typing nor typing_extensions provide a public WrappedCallable protocol to use to type it fully, it can at least be more correctly typed as accepting and emitting Callable[P, R] rather than using a single typevar).

The second one is related, but with slightly different symptoms, as it arises due to mypy inferring a narrow type for a sequence populated with partial objects, and now correctly detecting that the subsequent sequence transformation changes the type of the callables in the sequence. The fix would be to either use different names for the two sequences, or else declare the broader type for the first sequence.

Comment on lines -68 to -69
class ContextManager(_GeneratorContextManager[_T]):
def __call__(self, func: _C) -> _C: ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove __call__ here? The decorator library explicitly adds it and it's not part of the inheritance tree:

https://github.com/micheles/decorator/blob/519cc713878df4bfb768edfaa65e4bf2d875e421/src/decorator.py#L314-L318

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the stub tests failed complaining that the __call__ definition here no longer conformed to the parent class declaration of ContextDecorator.__call__ (it was the same issue as in jax: the return type isn't the same as the input type, it just has the same call signature).

Deleting the stub in the subclass means it inherits the ContextDecorator.__call__ signature, which is also correct for decorator.ContextManager.

I didn't check if the nominal signature of decorator.decorate itself might need updating, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I edited the PR description to include that additional information about the change fixing a genuine bug in the old signature definition.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, I missed some of the base classes of _GeneratorContextManager when checking (twice!).

Copy link
Collaborator

@srittau srittau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@srittau srittau merged commit 57d7c43 into python:main Jan 20, 2025
58 checks passed
@ncoghlan
Copy link
Contributor Author

I submitted a follow up PR to jax here: jax-ml/jax#25994

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve type hinting for ContextDecorator and AsyncContextDecorator
2 participants