From b9b66feab4e7af047aa3827b2227df6241610a5c Mon Sep 17 00:00:00 2001 From: Marko Ristin Date: Sat, 21 Sep 2024 07:43:43 +0200 Subject: [PATCH] Fixed invariants leak between related classes (#297) With pull request #292, we allowed users to specify the events which trigger the invariant checks for each individual invariant. This introduced a bug where invariants added to a child class were also added to a parent class. In this patch, we fixed the issue. Fixes #295. --- icontract/_checkers.py | 2 +- icontract/_metaclass.py | 38 ++++++++-- tests/test_inheritance_invariant.py | 111 ++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 9 deletions(-) diff --git a/icontract/_checkers.py b/icontract/_checkers.py index efb445f..5e18a67 100644 --- a/icontract/_checkers.py +++ b/icontract/_checkers.py @@ -1213,7 +1213,7 @@ def add_invariant_checks(cls: ClassT) -> None: # We need to ignore __repr__ to prevent endless loops when generating error messages. # We also need to ignore __getattribute__ since pretty much any operation on the instance # will result in an endless loop. - if name in ["__new__", "__repr__", "__getattribute__"]: + if name in ("__new__", "__repr__", "__getattribute__"): continue if name == "__init__": diff --git a/icontract/_metaclass.py b/icontract/_metaclass.py index 623d933..a835f03 100644 --- a/icontract/_metaclass.py +++ b/icontract/_metaclass.py @@ -23,23 +23,38 @@ def _collapse_invariants( - bases: List[type], namespace: MutableMapping[str, Any] + bases: List[type], namespace: MutableMapping[str, Any], invariants_dunder: str ) -> None: - """Collect invariants from the bases and merge them with the invariants in the namespace.""" + """ + Collect invariants from the bases and merge them with the invariants in the namespace. + + We do not only collapse ``__invariants__`` class property, but we also need to collapse + the filtered ``__invariants_on_call__`` and ``__invariants_on_setattr__``, as they are + sub-lists of the ``__invariants__``. + """ + assert invariants_dunder in ( + "__invariants__", + "__invariants_on_call__", + "__invariants_on_setattr__", + ), "Unexpected invariants_dunder: {!r}".format(invariants_dunder) + + # region Invariants invariants = [] # type: List[Contract] # Add invariants of the bases for base in bases: - if hasattr(base, "__invariants__"): - invariants.extend(getattr(base, "__invariants__")) + if hasattr(base, invariants_dunder): + invariants.extend(getattr(base, invariants_dunder)) # Add invariants in the current namespace - if "__invariants__" in namespace: - invariants.extend(namespace["__invariants__"]) + if invariants_dunder in namespace: + invariants.extend(namespace[invariants_dunder]) # Change the final invariants in the namespace if invariants: - namespace["__invariants__"] = invariants + namespace[invariants_dunder] = invariants + + # endregion def _collapse_preconditions( @@ -323,7 +338,14 @@ def _dbc_decorate_namespace( Instance methods are simply replaced with the decorated function/ Properties, class methods and static methods are overridden with new instances of ``property``, ``classmethod`` and ``staticmethod``, respectively. """ - _collapse_invariants(bases=bases, namespace=namespace) + for invariant_dunder in ( + "__invariants__", + "__invariants_on_call__", + "__invariants_on_setattr__", + ): + _collapse_invariants( + bases=bases, namespace=namespace, invariants_dunder=invariant_dunder + ) for key, value in namespace.items(): if inspect.isfunction(value) or isinstance(value, (staticmethod, classmethod)): diff --git a/tests/test_inheritance_invariant.py b/tests/test_inheritance_invariant.py index 9672d6c..4c9b48c 100644 --- a/tests/test_inheritance_invariant.py +++ b/tests/test_inheritance_invariant.py @@ -51,6 +51,117 @@ def some_func(self) -> int: "Invariant is expected to run before and after the method call.", ) + def test_level_1_inheritance_of_invariants_does_not_leak_to_parents(self) -> None: + # NOTE (mristin): + # This is a regression test for: + # https://github.com/Parquery/icontract/issues/295 + # + # The invariants added to a child class were unexpectedly leaked back to + # the parent class. + + @icontract.invariant(lambda: True) + class Base(icontract.DBC): + def do_something(self) -> None: + pass + + def __repr__(self) -> str: + return "instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda: False) + class Derived(Base): + pass + + Base() + + # NOTE (mristin): + # This produced an unexpected violation error. + Base().do_something() + + had_violation_error = False + try: + Derived() + except icontract.ViolationError: + had_violation_error = True + + assert had_violation_error + + def test_level_2_inheritance_of_invariants_does_not_leak_to_parents(self) -> None: + # NOTE (mristin): + # This is a regression test for: + # https://github.com/Parquery/icontract/issues/295 + # + # The invariants added to a child class were unexpectedly leaked back to + # the parent class. + + @icontract.invariant(lambda: True) + class Base(icontract.DBC): + def do_something(self) -> None: + pass + + def __repr__(self) -> str: + return "instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda: True) + class Derived(Base): + pass + + @icontract.invariant(lambda: False) + class DerivedDerived(Base): + pass + + Base() + Base().do_something() + Derived() + + had_violation_error = False + try: + DerivedDerived() + except icontract.ViolationError: + had_violation_error = True + + assert had_violation_error + + # noinspection PyUnusedLocal + def test_level_3_inheritance_of_invariants_does_not_leak_to_parents(self) -> None: + # NOTE (mristin): + # This is a regression test for: + # https://github.com/Parquery/icontract/issues/295 + # + # The invariants added to a child class were unexpectedly leaked back to + # the parent class. + + class A(icontract.DBC): + def do_something(self) -> None: + pass + + def __repr__(self) -> str: + return "instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda: True) + class B(A): + pass + + # NOTE (mristin): + # CFalse should not in any way influence A, B and CTrue, but it did due to + # a bug. + @icontract.invariant(lambda: False) + class CFalse(B): # pylint: disable=unused-variable + pass + + @icontract.invariant(lambda: True) + class CTrue(B): + pass + + A() + + CTrue() + + A().do_something() + + # NOTE (mristin): + # This produced an unexpected violation error. + CTrue().do_something() + class TestViolation(unittest.TestCase): def test_inherited(self) -> None: