Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stage checking #278

Merged
merged 7 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Fixes
* fix rcoulomb in CHARMM energy minimization MDP template file (PR #210)
* fix ensemble.EnsembleAnalysis.check_groups_from_common_ensemble (#212)
* updated versioneer (#285)
* fix that simulation stages cannot be restarted after error (#272)


2022-01-03 0.8.0
Expand Down
29 changes: 21 additions & 8 deletions mdpow/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def incomplete(self):
def incomplete(self, stage):
if not stage in self.stages:
raise ValueError(
"can only assign a registered stage from %(stages)r" % vars(self)
"Can only assign a registered stage from %(stages)r" % vars(self)
)
self.__incomplete = stage

Expand All @@ -143,7 +143,7 @@ def completed(self, stage):

def start(self, stage):
"""Record that *stage* is starting."""
if self.current is not None:
if self.current is not None and self.current != stage:
errmsg = (
"Cannot start stage %s because previously started stage %s "
"has not been completed." % (stage, self.current)
Expand All @@ -157,7 +157,16 @@ def has_completed(self, stage):
return stage in self.history

def has_not_completed(self, stage):
"""Returns ``True`` if the *stage* had been started but not completed yet."""
"""Returns ``True`` if the *stage* had been started but not completed yet.

This is subtly different from ``not`` :func:`has_completed` in
that two things have to be true:

1. No stage is active (which is the case when a restart is attempted).
2. The `stage` has not been completed previously (i.e.,
:func:`has_completed` returns ``False``)

"""
return self.current is None and not self.has_completed(stage)

def clear(self):
Expand Down Expand Up @@ -190,13 +199,14 @@ def __init__(self, *args, **kwargs):
len(self.journal.history)
except AttributeError:
self.journal = Journal(self.protocols)
super(Journalled, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)

def get_protocol(self, protocol):
"""Return method for *protocol*.

- If *protocol* is a real method of the class then the method is
returned.
returned. This method should implement its own use of
:meth:`Journal.start` and :meth:`Journal.completed`.

- If *protocol* is a registered protocol name but no method of
the name exists (i.e. *protocol* is a "dummy protocol") then
Expand All @@ -205,9 +215,12 @@ def get_protocol(self, protocol):

.. function:: dummy_protocol(func, *args, **kwargs)

Runs *func* with the arguments and keywords between calls
to :meth:`Journal.start` and :meth:`Journal.completed`,
with the stage set to *protocol*.
Runs *func* with the arguments and keywords between calls to
:meth:`Journal.start` and :meth:`Journal.completed`, with the
stage set to *protocol*.

The function should return ``True`` on success and ``False`` on
failure.

- Raises a :exc:`ValueError` if the *protocol* is not
registered (i.e. not found in :attr:`Journalled.protocols`).
Expand Down
175 changes: 175 additions & 0 deletions mdpow/tests/test_journals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import pytest

from mdpow import restart


@pytest.fixture
def journal():
return restart.Journal(["pre", "main", "post"])


class TestJournal:
def test_full_sequence(self, journal):
journal.start("pre")
assert journal.current == "pre"
journal.completed("pre")

journal.start("main")
assert journal.current == "main"
journal.completed("main")

journal.start("post")
assert journal.current == "post"
journal.completed("post")

def test_set_wrong_stage_ValueError(self, journal):
with pytest.raises(ValueError, match="Can only assign a registered stage"):
journal.start("BEGIN !")

def test_JournalSequenceError_no_completion(self, journal):
with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"):
journal.start("pre")
assert journal.current == "pre"

journal.start("main")

@pytest.mark.xfail
def test_JournalSequenceError_skip_stage(self, journal):
# Currently allows skipping a stage and does not enforce ALL previous
# stages to have completed.
with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"):
journal.start("pre")
assert journal.current == "pre"
journal.completed("pre")

journal.start("post")

def test_start_idempotent(self, journal):
# test that start() can be called multiple time (#278)
journal.start("pre")
journal.start("pre")
assert journal.current == "pre"

def test_incomplete_known_stage(self, journal):
journal.incomplete = "main"
assert journal.incomplete == "main"

def test_incomplete_unknown_stage_ValueError(self, journal):
with pytest.raises(ValueError, match="Can only assign a registered stage from"):
journal.incomplete = "BEGIN !"

def test_clear(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
# manually setting incomplete
journal.incomplete = journal.current

assert journal.current == "main"
assert journal.incomplete == journal.current

journal.clear()
assert journal.current is None
assert journal.incomplete is None

def test_history(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
journal.completed("main")
journal.start("post")

# completed stages
assert journal.history == ["pre", "main"]

def test_history_del(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
journal.completed("main")
assert journal.history

del journal.history
assert journal.history == []

def test_has_completed(self, journal):
journal.start("pre")
journal.completed("pre")

assert journal.has_completed("pre")
assert not journal.has_completed("main")

def test_has_not_completed(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
# simulate crash/restart
del journal.current

assert journal.has_not_completed("main")
assert not journal.has_not_completed("pre")


# need a real class so that it can be pickled later
class JournalledMemory(restart.Journalled):
# divide is a dummy protocol
protocols = ["divide", "multiply"]

def __init__(self):
self.memory = 1
super().__init__()

def multiply(self, x):
self.journal.start("multiply")
self.memory *= x
self.journal.completed("multiply")


@pytest.fixture
def journalled():
return JournalledMemory()


class TestJournalled:
@staticmethod
def divide(m, x):
return m.memory / x

def test_get_protocol_of_class(self, journalled):
f = journalled.get_protocol("multiply")
f(10)
assert journalled.memory == 10
assert journalled.journal.has_completed("multiply")

def test_get_protocol_dummy(self, journalled):
dummy_protocol = journalled.get_protocol("divide")
result = dummy_protocol(self.divide, journalled, 10)

assert result == 1 / 10
assert journalled.journal.has_completed("divide")

def test_get_protocol_dummy_incomplete(self, journalled):
dummy_protocol = journalled.get_protocol("divide")
with pytest.raises(ZeroDivisionError):
result = dummy_protocol(self.divide, journalled, 0)
assert not journalled.journal.has_completed("divide")

def test_save_load(self, tmp_path):
# instantiate a class that can be pickled (without pytest magic)
journalled = JournalledMemory()
f = journalled.get_protocol("multiply")
f(10)
assert journalled.memory == 10

pickle = tmp_path / "memory.pkl"
journalled.save(pickle)

assert pickle.exists()

# change instance
f(99)
assert journalled.memory == 10 * 99

# reload previous state
journalled.load(pickle)
assert journalled.memory == 10
Loading