Skip to content

Commit

Permalink
Allowed to enforce invariants on attribute setting
Browse files Browse the repository at this point in the history
Originally, we had enforced invariants only at calls to "normal"
methods, and excluded ``__setattr__`` since it is usually too expensive
to verify invariants whenever setting an attribute.

However, there are use cases where the users prefer to incur to
computational overhead for correctness. To that end, we introduced the
feature in this patch to steer when the invariants are enforced (at
method calls, on setting attributes, or in both situations).

Fixes #291.
  • Loading branch information
mristin committed Jul 5, 2024
1 parent 8edda0d commit 9212d87
Show file tree
Hide file tree
Showing 10 changed files with 437 additions and 66 deletions.
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ DBC
ViolationError
--------------
.. autoclass:: ViolationError

InvariantCheckEvent
-------------------
.. autoclass:: InvariantCheckEvent
85 changes: 68 additions & 17 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,36 @@ Invariants
----------
Invariants are special contracts associated with an instance of a class. An invariant should hold *after* initialization
and *before* and *after* a call to any public instance method. The invariants are the pivotal element of
design-by-contract: they allow you to formally define properties of a data structures that you know will be maintained
design-by-contract: they allow you to formally define properties of data structures that you know will be maintained
throughout the life time of *every* instance.

We consider the following methods to be "public":

* All methods not prefixed with ``_``
* All magic methods (prefix ``__`` and suffix ``__``)

Class methods (marked with ``@classmethod`` or special dunders such as ``__new__``) can not observe the invariant
since they are not associated with an instance of the class.
Class methods (marked with ``@classmethod`` or special dunders such as ``__new__``) can not observe the invariants
since they are not associated with an instance of the class. We also exempt ``__getattribute__`` method from observing
the invariants since these functions alter the state of the instance and thus can not be considered "public".
We exempt ``__repr__`` method as well to prevent endless loops when generating error messages.
At runtime, many icontract-specific dunder attributes (such as ``__invariants__``) need to be accessed, so the method
``__getattribute__`` can not be decorated lest we end up in an endless recursion.

We exempt ``__getattribute__``, ``__setattr__`` and ``__delattr__`` methods from observing the invariant since
these functions alter the state of the instance and thus can not be considered "public".
By default, we do not enforce the invariants on calls to ``__setattr__`` as that is usually
prohibitively expensive in terms of computation for most use cases. However, there is a parameter
``check_on`` to an :class:`invariant` which allows you to steer in a more fine-grained manner when the invariant should
be enforced.

We also exempt ``__repr__`` method to prevent endless loops when generating error messages.
The default value of ``check_on`` is set to :attr:`InvariantCheckEvent.CALL`, meaning that we check
the invariants only in the calls to the methods *excluding* ``__setattr__``. If you want to check
the invariants *only* on ``__setattr__`` and excluding *any* other method, set it to :attr:`InvariantCheckEvent.SETATTR`.
The combinations is also possible; to check invariants on method calls *including* ``__setattr__``, set ``check_on`` to
:attr:`InvariantCheckEvent.CALL` ``|`` :attr:`InvariantCheckEvent.SETATTR`.

The icontract invariants are implemented as class decorators.
.. note::

The property getters and setters are considered "normal" methods. If you want to check the invariants at property
getters and/or setters, make sure to include :attr:`InvariantCheckEvent.CALL` in ``check_on``.

The following examples show various cases when an invariant is breached.

Expand Down Expand Up @@ -229,6 +242,45 @@ After the invocation of a magic method:
self was an instance of SomeClass
self.x was -1
Enforcing the invariants on the method calls *including* ``__setattr__``:

.. code-block:: python
>>> @icontract.invariant(
... lambda self: self.x > 0,
... check_on=(
... icontract.InvariantCheckEvent.CALL
... | icontract.InvariantCheckEvent.SETATTR
... )
... )
... class SomeClass:
... def __init__(self) -> None:
... self.x = 100
...
... def do_something_bad(self) -> None:
... self.x = -1
...
... def __repr__(self) -> str:
... return "an instance of SomeClass"
...
>>> some_instance = SomeClass()
>>> some_instance.do_something_bad()
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[26]>, line 1 in <module>:
self.x > 0:
self was an instance of SomeClass
self.x was -1
>>> another_instance = SomeClass()
>>> another_instance.x = -1
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[26]>, line 1 in <module>:
self.x > 0:
self was an instance of SomeClass
self.x was -1
Snapshots (a.k.a "old" argument values)
---------------------------------------
Usual postconditions can not verify the state transitions of the function's argument values. For example, it is
Expand Down Expand Up @@ -261,7 +313,7 @@ Here is an example that uses snapshots to check that a value was appended to the
>>> some_func(lst=[1, 2], value=3)
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[28]>, line 2 in <module>:
icontract.errors.ViolationError: File <doctest usage.rst[33]>, line 2 in <module>:
lst == OLD.lst + [value]:
OLD was a bunch of OLD values
OLD.lst was [1, 2]
Expand All @@ -285,7 +337,7 @@ The following example shows how you can name the snapshot:
>>> some_func(lst=[1, 2], value=3)
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[32]>, line 2 in <module>:
icontract.errors.ViolationError: File <doctest usage.rst[37]>, line 2 in <module>:
len(lst) == OLD.len_lst + 1:
OLD was a bunch of OLD values
OLD.len_lst was 2
Expand All @@ -311,7 +363,7 @@ The next code snippet shows how you can combine multiple arguments of a function
>>> some_func(lst_a=[1, 2], lst_b=[3, 4]) # doctest: +ELLIPSIS
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[36]>, line ... in <module>:
icontract.errors.ViolationError: File <doctest usage.rst[...]>, line ... in <module>:
set(lst_a).union(lst_b) == OLD.union:
OLD was a bunch of OLD values
OLD.union was {1, 2, 3, 4}
Expand Down Expand Up @@ -394,7 +446,7 @@ The following example shows an abstract parent class and a child class that inhe
>>> some_b.func(y=0)
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[40]>, line 7 in A:
icontract.errors.ViolationError: File <doctest usage.rst[45]>, line 7 in A:
result < y:
result was 1
self was an instance of B
Expand All @@ -405,7 +457,7 @@ The following example shows an abstract parent class and a child class that inhe
>>> another_b.break_parent_invariant()
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[40]>, line 1 in <module>:
icontract.errors.ViolationError: File <doctest usage.rst[45]>, line 1 in <module>:
self.x > 0:
self was an instance of B
self.x was -1
Expand All @@ -415,7 +467,7 @@ The following example shows an abstract parent class and a child class that inhe
>>> yet_another_b.break_my_invariant()
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[41]>, line 1 in <module>:
icontract.errors.ViolationError: File <doctest usage.rst[46]>, line 1 in <module>:
self.x < 100:
self was an instance of B
self.x was 101
Expand Down Expand Up @@ -453,7 +505,7 @@ The following example shows how preconditions are weakened:
>>> b.func(x=5)
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[49]>, line 2 in B:
icontract.errors.ViolationError: File <doctest usage.rst[54]>, line 2 in B:
x % 3 == 0:
self was an instance of B
x was 5
Expand Down Expand Up @@ -484,7 +536,7 @@ The example below illustrates how snapshots are inherited:
>>> b.func(lst=[1, 2], value=3)
Traceback (most recent call last):
...
icontract.errors.ViolationError: File <doctest usage.rst[54]>, line 4 in A:
icontract.errors.ViolationError: File <doctest usage.rst[59]>, line 4 in A:
len(lst) == len(OLD.lst) + 1:
OLD was a bunch of OLD values
OLD.lst was [1, 2]
Expand All @@ -495,7 +547,6 @@ The example below illustrates how snapshots are inherited:
self was an instance of B
value was 3
Toggling Contracts
------------------
By default, the contract checks (including the snapshots) are always performed at run-time. To disable them, run the
Expand Down Expand Up @@ -607,7 +658,7 @@ Here is an example of the error given as a subclass of `BaseException`_:
>>> some_func(x=0)
Traceback (most recent call last):
...
ValueError: File <doctest usage.rst[62]>, line 1 in <module>:
ValueError: File <doctest usage.rst[67]>, line 1 in <module>:
x > 0: x was 0
Here is an example of the error given as an instance of a `BaseException`_:
Expand Down
2 changes: 2 additions & 0 deletions icontract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@
import icontract.errors

ViolationError = icontract.errors.ViolationError

InvariantCheckEvent = icontract._types.InvariantCheckEvent
100 changes: 68 additions & 32 deletions icontract/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

import icontract._represent
from icontract._globals import CallableT, ClassT
from icontract._types import Contract, Snapshot
from icontract._types import Contract, Snapshot, InvariantCheckEvent
from icontract.errors import ViolationError


# pylint does not play with typing.Mapping.
# pylint: disable=unsubscriptable-object
# pylint: disable=raising-bad-type
Expand Down Expand Up @@ -965,8 +966,8 @@ def wrapper(*args, **kwargs): # type: ignore
"""Pass the arguments to __new__ and check invariants on the result."""
instance = new_func(*args, **kwargs)

for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in instance.__class__.__invariants__:
_assert_invariant(contract=invariant, instance=instance)

return instance

Expand All @@ -979,7 +980,7 @@ def wrapper(*args, **kwargs): # type: ignore

def _decorate_with_invariants(func: CallableT, is_init: bool) -> CallableT:
"""
Decorate the function ``func`` of the class ``cls`` with invariant checks.
Decorate the method ``func`` with invariant checks.
If the function has been already decorated with invariant checks, the function returns immediately.
Expand Down Expand Up @@ -1025,8 +1026,8 @@ def wrapper(*args, **kwargs): # type: ignore
try:
result = func(*args, **kwargs)

for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in instance.__class__.__invariants__:
_assert_invariant(contract=invariant, instance=instance)

return result
finally:
Expand Down Expand Up @@ -1060,6 +1061,12 @@ async def wrapper(*args, **kwargs): # type: ignore
).format(func, param_names, args, kwargs)
) from err

invariants = (
instance.__class__.__invariants_on_setattr__
if func.__name__ == "__setattr__"
else instance.__class__.__invariants_on_call__
)

# We need to create a new in-progress set if it is None as the ``ContextVar`` does not accept
# a factory function for the default argument. If we didn't do this, and simply set an empty
# set as the default, ``ContextVar`` would always point to the same set by copying the default
Expand All @@ -1080,13 +1087,13 @@ async def wrapper(*args, **kwargs): # type: ignore

# ExitStack is not used here due to performance.
try:
for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in invariants:
_assert_invariant(contract=invariant, instance=instance)

result = await func(*args, **kwargs)

for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in invariants:
_assert_invariant(contract=invariant, instance=instance)

return result
finally:
Expand All @@ -1108,6 +1115,12 @@ def wrapper(*args, **kwargs): # type: ignore
).format(func, param_names, args, kwargs)
) from err

invariants = (
instance.__class__.__invariants_on_setattr__
if func.__name__ == "__setattr__"
else instance.__class__.__invariants_on_call__
)

# The following dunder indicates whether another invariant is currently being checked. If so,
# we need to suspend any further invariant check to avoid endless recursion.

Expand All @@ -1129,13 +1142,13 @@ def wrapper(*args, **kwargs): # type: ignore

# ExitStack is not used here due to performance.
try:
for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in invariants:
_assert_invariant(contract=invariant, instance=instance)

result = func(*args, **kwargs)

for contract in instance.__class__.__invariants__:
_assert_invariant(contract=contract, instance=instance)
for invariant in invariants:
_assert_invariant(contract=invariant, instance=instance)

return result
finally:
Expand Down Expand Up @@ -1169,25 +1182,38 @@ def _already_decorated_with_invariants(func: CallableT) -> bool:
def add_invariant_checks(cls: ClassT) -> None:
"""Decorate each of the class functions with invariant checks if not already decorated."""
# Candidates for the decoration as list of (name, dir() value)
init_name_func = None # type: Optional[Tuple[str, Callable[..., None]]]
init_func = None # type: Optional[Callable[..., None]]
names_funcs = [] # type: List[Tuple[str, Callable[..., None]]]
names_properties = [] # type: List[Tuple[str, property]]

# Filter out entries in the directory which are certainly not candidates for decoration.
# As we continuously decorate the class with invariants, we never definitely know
# whether this decoration is the last one. Hence, we can only retrieve the list
# of invariants decorated *thus far*. As we only add one invariant at the time,
# we only need to check for the last invariant.
assert cls.__invariants__ is not None, ( # type: ignore
"Expected to set ``__invariants__`` in the invariant decorator before "
"the call to {}".format(add_invariant_checks.__name__)
)
assert len(cls.__invariants__) > 0, ( # type: ignore
"Expected at least one invariant in the ``__invariants__`` since we expect "
"to push the latest invariant in the invariant decorator before the call to "
"{}".format(add_invariant_checks.__name__)
)
last_invariant = cls.__invariants__[-1] # type: ignore
assert isinstance(last_invariant, icontract._types.Invariant)

# Filter out entries in the directory which are certainly not candidates for decoration
# regarding the ``last_invariant``. Note that the functions which are already decorated
# will not be re-decorated, so that this loop runs in O( dir(cls) * len(invariants) ),
# but with a negligible constant.
for name in dir(cls):
value = getattr(cls, name)

# __new__ is a special class method (though not marked properly with @classmethod!).
# We need to ignore __repr__ to prevent endless loops when generating error messages.
# __getattribute__, __setattr__ and __delattr__ are too invasive and alter the state of the instance.
# Hence we don't consider them "public".
if name in [
"__new__",
"__repr__",
"__getattribute__",
"__setattr__",
"__delattr__",
]:
# We also need to ignore __getattribute__ since pretty much any operation on the instance
# will result in an endless loop.
if name in ["__new__", "__repr__", "__getattribute__"]:
continue

if name == "__init__":
Expand All @@ -1197,7 +1223,19 @@ def add_invariant_checks(cls: ClassT) -> None:
type(value)
)

init_name_func = (name, value)
init_func = value
continue

if (
name != "__setattr__"
and InvariantCheckEvent.CALL not in last_invariant.check_on
):
continue

if (
name == "__setattr__"
and InvariantCheckEvent.SETATTR not in last_invariant.check_on
):
continue

if (
Expand Down Expand Up @@ -1234,18 +1272,16 @@ def add_invariant_checks(cls: ClassT) -> None:
)
)

if init_name_func:
name, func = init_name_func

if init_func:
# We have to distinguish this special case which is used by named
# tuples and possibly other optimized data structures.
# In those cases, we have to wrap __new__ instead of __init__.
if func == object.__init__ and hasattr(cls, "__new__"):
if init_func == object.__init__ and hasattr(cls, "__new__"):
new_func = getattr(cls, "__new__")
setattr(cls, "__new__", _decorate_new_with_invariants(new_func))
else:
wrapper = _decorate_with_invariants(func=func, is_init=True)
setattr(cls, name, wrapper)
wrapper = _decorate_with_invariants(func=init_func, is_init=True)
setattr(cls, init_func.__name__, wrapper)

for name, func in names_funcs:
wrapper = _decorate_with_invariants(func=func, is_init=False)
Expand Down
Loading

0 comments on commit 9212d87

Please sign in to comment.