diff --git a/exceptiongroup/__init__.py b/exceptiongroup/__init__.py index ef2a381..d560757 100644 --- a/exceptiongroup/__init__.py +++ b/exceptiongroup/__init__.py @@ -33,6 +33,7 @@ def __init__(self, message, exceptions, sources): raise TypeError( "Expected an exception object, not {!r}".format(exc) ) + self.message = message self.sources = list(sources) if len(self.sources) != len(self.exceptions): raise ValueError( @@ -41,6 +42,16 @@ def __init__(self, message, exceptions, sources): ) ) + # copy.copy doesn't work for ExceptionGroup, because BaseException have + # rewrite __reduce_ex__ method. We need to add __copy__ method to + # make it can be copied. + def __copy__(self): + new_group = self.__class__(self.message, self.exceptions, self.sources) + new_group.__traceback__ = self.__traceback__ + new_group.__context__ = self.__context__ + new_group.__cause__ = self.__cause__ + return new_group + from . import _monkeypatch from ._tools import split, catch diff --git a/exceptiongroup/_tests/test_exceptiongroup.py b/exceptiongroup/_tests/test_exceptiongroup.py deleted file mode 100644 index e67928c..0000000 --- a/exceptiongroup/_tests/test_exceptiongroup.py +++ /dev/null @@ -1,2 +0,0 @@ -import traceback -from exceptiongroup import ExceptionGroup, split, catch diff --git a/exceptiongroup/_tests/test_tools.py b/exceptiongroup/_tests/test_tools.py new file mode 100644 index 0000000..b385bc6 --- /dev/null +++ b/exceptiongroup/_tests/test_tools.py @@ -0,0 +1,117 @@ +from exceptiongroup import ExceptionGroup, split + + +def raise_error(err): + raise err + + +def raise_error_from_another(out_err, another_err): + # use try..except approache so out_error have meaningful + # __context__, __cause__ attribute. + try: + raise another_err + except Exception as e: + raise out_err from e + + +def test_split_for_none_exception(): + matched, unmatched = split(RuntimeError, None) + assert matched is None + assert unmatched is None + + +def test_split_when_all_exception_matched(): + group = ExceptionGroup( + "Many Errors", + [RuntimeError("Runtime Error1"), RuntimeError("Runtime Error2")], + ["Runtime Error1", "Runtime Error2"] + ) + matched, unmatched = split(RuntimeError, group) + assert matched is group + assert unmatched is None + + +def test_split_when_all_exception_unmatched(): + group = ExceptionGroup( + "Many Errors", + [RuntimeError("Runtime Error1"), RuntimeError("Runtime Error2")], + ["Runtime Error1", "Runtime Error2"] + ) + matched, unmatched = split(ValueError, group) + assert matched is None + assert unmatched is group + + +def test_split_when_contains_matched_and_unmatched(): + error1 = RuntimeError("Runtime Error1") + error2 = ValueError("Value Error2") + group = ExceptionGroup( + "Many Errors", + [error1, error2], + ['Runtime Error1', 'Value Error2'] + ) + matched, unmatched = split(RuntimeError, group) + assert isinstance(matched, ExceptionGroup) + assert isinstance(unmatched, ExceptionGroup) + assert matched.exceptions == [error1] + assert matched.message == "Many Errors" + assert matched.sources == ['Runtime Error1'] + assert unmatched.exceptions == [error2] + assert unmatched.message == "Many Errors" + assert unmatched.sources == ['Value Error2'] + + +def test_split_with_predicate(): + def _match(err): + return str(err) != 'skip' + + error1 = RuntimeError("skip") + error2 = RuntimeError("Runtime Error") + group = ExceptionGroup( + "Many Errors", + [error1, error2], + ['skip', 'Runtime Error'] + ) + matched, unmatched = split(RuntimeError, group, match=_match) + assert matched.exceptions == [error2] + assert unmatched.exceptions == [error1] + + +def test_split_with_single_exception(): + err = RuntimeError("Error") + matched, unmatched = split(RuntimeError, err) + assert matched is err + assert unmatched is None + + matched, unmatched = split(ValueError, err) + assert matched is None + assert unmatched is err + + +def test_split_and_check_attributes_same(): + try: + raise_error(RuntimeError("RuntimeError")) + except Exception as e: + run_error = e + + try: + raise_error(ValueError("ValueError")) + except Exception as e: + val_error = e + + group = ExceptionGroup( + "ErrorGroup", [run_error, val_error], ["RuntimeError", "ValueError"] + ) + # go and check __traceback__, __cause__ attributes + try: + raise_error_from_another(group, RuntimeError("Cause")) + except BaseException as e: + new_group = e + + matched, unmatched = split(RuntimeError, group) + assert matched.__traceback__ is new_group.__traceback__ + assert matched.__cause__ is new_group.__cause__ + assert matched.__context__ is new_group.__context__ + assert unmatched.__traceback__ is new_group.__traceback__ + assert unmatched.__cause__ is new_group.__cause__ + assert unmatched.__context__ is new_group.__context__ diff --git a/exceptiongroup/_tools.py b/exceptiongroup/_tools.py index 5ab124c..3922c81 100644 --- a/exceptiongroup/_tools.py +++ b/exceptiongroup/_tools.py @@ -2,9 +2,22 @@ # Core primitives for working with ExceptionGroups ################################################################ +import copy from . import ExceptionGroup + def split(exc_type, exc, *, match=None): + """ splits the exception into one half (matched) representing all the parts of + the exception that match the predicate, and another half (not matched) + representing all the parts that don't match. + + Args: + exc_type (type of exception): The exception type we use to split. + exc (BaseException): Exception object we want to split. + match (None or func): predicate function to restict the split process, + if the argument is not None, only exceptions with match(exception) + will go into matched part. + """ if exc is None: return None, None elif isinstance(exc, ExceptionGroup): @@ -13,9 +26,9 @@ def split(exc_type, exc, *, match=None): rests = [] rest_notes = [] for subexc, note in zip(exc.exceptions, exc.sources): - match, rest = ExceptionGroup.split(exc_type, subexc, match=match) - if match is not None: - matches.append(match) + matched, rest = split(exc_type, subexc, match=match) + if matched is not None: + matches.append(matched) match_notes.append(note) if rest is not None: rests.append(rest) @@ -128,5 +141,18 @@ def __exit__(self, etype, exc, tb): exceptiongroup_catch_exc.__context__ = saved_context -def catch(cls, exc_type, handler, match=None): +def catch(exc_type, handler, match=None): + """Return a context manager that catches and re-throws exception. + after running :meth:`handle` on them. + + Args: + exc_type: An exception type or A tuple of exception type that need + to be handled by ``handler``. Exceptions which doesn't belong to + exc_type or doesn't match the predicate will not be handled by + ``handler``. + handler: the handler to handle exception which match exc_type and + predicate. + match: when the match is not None, ``handler`` will only handle when + match(exc) is True + """ return Catcher(exc_type, handler, match)