diff --git a/acacore/utils/helpers.py b/acacore/utils/helpers.py index 8044693..9287919 100644 --- a/acacore/utils/helpers.py +++ b/acacore/utils/helpers.py @@ -1,5 +1,6 @@ from types import TracebackType from typing import Optional +from typing import Sequence from typing import Type @@ -14,14 +15,16 @@ class ExceptionManager: exception (Optional[BaseException]): The exception that was raised within the context, if any. traceback (Optional[TracebackType]): The traceback associated with the exception, if any. catch (tuple[Type[BaseException], ...]): Tuple of exceptions that should be caught instead of letting them rise. + allow (tuple[Type[BaseException], ...]): Tuple of exceptions that should be allowed to rise. """ - __slots__ = ("exception", "traceback", "catch") + __slots__ = ("exception", "traceback", "catch", "allow") - def __init__(self, *catch: Type[BaseException]) -> None: + def __init__(self, *catch: Type[BaseException], allow: Optional[Sequence[Type[BaseException]]] = None) -> None: self.exception: Optional[BaseException] = None self.traceback: Optional[TracebackType] = None self.catch: tuple[Type[BaseException], ...] = catch + self.allow: tuple[Type[BaseException], ...] = tuple(allow or []) def __enter__(self) -> "ExceptionManager": return self @@ -34,4 +37,8 @@ def __exit__( ) -> bool: self.exception = exc_val self.traceback = exc_tb - return any(issubclass(exc_type, e) for e in self.catch) if exc_type else False + + if not exc_type: + return False + + return any(issubclass(exc_type, e) for e in self.catch) and not any(issubclass(exc_type, e) for e in self.allow)