From d6872e3e3acbdebfa7b8da485e564612d20e875c Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Tue, 14 Jan 2025 19:41:07 +0100 Subject: [PATCH] Accept generic ExceptionGroups for raises Closes #13115 --- AUTHORS | 1 + changelog/13115.improvement.rst | 1 + src/_pytest/python_api.py | 45 +++++++++++++++++++++++++++++---- testing/code/test_excinfo.py | 27 ++++++++++++++++++++ 4 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 changelog/13115.improvement.rst diff --git a/AUTHORS b/AUTHORS index 8a1a7d183a..8600735c8b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -435,6 +435,7 @@ Tim Hoffmann Tim Strazny TJ Bruno Tobias Diez +Tobias Petersen Tom Dalton Tom Viner Tomáš Gavenčiak diff --git a/changelog/13115.improvement.rst b/changelog/13115.improvement.rst new file mode 100644 index 0000000000..c77383c154 --- /dev/null +++ b/changelog/13115.improvement.rst @@ -0,0 +1 @@ +Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on ExcInfo. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 25cf9f04d6..97e673f43b 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -12,10 +12,13 @@ from numbers import Complex import pprint import re +import sys from types import TracebackType from typing import Any from typing import cast from typing import final +from typing import get_args +from typing import get_origin from typing import overload from typing import TYPE_CHECKING from typing import TypeVar @@ -24,6 +27,10 @@ from _pytest.outcomes import fail +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + from exceptiongroup import ExceptionGroup + if TYPE_CHECKING: from numpy import ndarray @@ -954,15 +961,43 @@ def raises( f"Raising exceptions is already understood as failing the test, so you don't need " f"any special code to say 'this should never raise an exception'." ) + + expected_exceptions: tuple[type[E], ...] + origin_exc: type[E] | None = get_origin(expected_exception) if isinstance(expected_exception, type): - expected_exceptions: tuple[type[E], ...] = (expected_exception,) + expected_exceptions = (expected_exception,) + elif origin_exc and issubclass(origin_exc, BaseExceptionGroup): + expected_exceptions = (cast(type[E], expected_exception),) else: expected_exceptions = expected_exception - for exc in expected_exceptions: - if not isinstance(exc, type) or not issubclass(exc, BaseException): + + def validate_exc(exc: type[E]) -> type[E]: + origin_exc: type[E] | None = get_origin(exc) + if origin_exc and issubclass(origin_exc, BaseExceptionGroup): + exc_type = get_args(exc)[0] + if issubclass(origin_exc, ExceptionGroup) and exc_type is Exception: + return cast(type[E], origin_exc) + elif ( + issubclass(origin_exc, BaseExceptionGroup) and exc_type is BaseException + ): + return cast(type[E], origin_exc) + else: + raise ValueError( + f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` " + f"are accepted as generic types but got `{exc}`. " + f"As `raises` will catch all instances of the specified group regardless of the " + f"generic argument specific nested exceptions has to be checked " + f"with `ExceptionInfo.group_contains()`" + ) + + elif not isinstance(exc, type) or not issubclass(exc, BaseException): msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable] not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__ raise TypeError(msg.format(not_a)) + else: + return exc + + expected_exceptions = tuple(validate_exc(exc) for exc in expected_exceptions) message = f"DID NOT RAISE {expected_exception}" @@ -973,14 +1008,14 @@ def raises( msg += ", ".join(sorted(kwargs)) msg += "\nUse context-manager form instead?" raise TypeError(msg) - return RaisesContext(expected_exception, message, match) + return RaisesContext(expected_exceptions, message, match) else: func = args[0] if not callable(func): raise TypeError(f"{func!r} object (type: {type(func)}) must be callable") try: func(*args[1:], **kwargs) - except expected_exception as e: + except expected_exceptions as e: return _pytest._code.ExceptionInfo.from_exception(e) fail(message) diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 22e695977e..ae2f208417 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -31,6 +31,7 @@ from _pytest._code.code import TracebackStyle if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup from exceptiongroup import ExceptionGroup @@ -453,6 +454,32 @@ def test_division_zero(): result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) +def test_raises_accepts_generic_group() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup[Exception]) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + +def test_raises_accepts_generic_base_group() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + +def test_raises_rejects_specific_generic_group() -> None: + with pytest.raises(ValueError): + pytest.raises(ExceptionGroup[RuntimeError]) + + +def test_raises_accepts_generic_group_in_tuple() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + class TestGroupContains: def test_contains_exception_type(self) -> None: exc_group = ExceptionGroup("", [RuntimeError()])