Skip to content

Commit

Permalink
Merge pull request #647 from OpenFreeEnergy/get_simsteps_changes
Browse files Browse the repository at this point in the history
Change get_simsteps function to handle NVT equilibration in plain MD protocol
  • Loading branch information
IAlibay authored Nov 22, 2023
2 parents 9d3ccbf + c45f85a commit c24ebb7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 64 deletions.
10 changes: 7 additions & 3 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,13 @@ def _run_simulation(
# Get the relevant simulation steps
mc_steps = settings['integrator_settings'].n_steps.m

equil_steps, prod_steps = settings_validation.get_simsteps(
equil_length=settings['simulation_settings'].equilibration_length,
prod_length=settings['simulation_settings'].production_length,
equil_steps = settings_validation.get_simsteps(
sim_length=settings['simulation_settings'].equilibration_length,
timestep=settings['integrator_settings'].timestep,
mc_steps=mc_steps,
)
prod_steps = settings_validation.get_simsteps(
sim_length=settings['simulation_settings'].production_length,
timestep=settings['integrator_settings'].timestep,
mc_steps=mc_steps,
)
Expand Down
11 changes: 7 additions & 4 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,13 @@ def run(self, *, dry=False, verbose=True,
settings_validation.validate_timestep(
forcefield_settings.hydrogen_mass, timestep
)
equil_steps, prod_steps = settings_validation.get_simsteps(
equil_length=sim_settings.equilibration_length,
prod_length=sim_settings.production_length,
timestep=timestep, mc_steps=mc_steps
equil_steps = settings_validation.get_simsteps(
sim_length=sim_settings.equilibration_length,
timestep=timestep, mc_steps=mc_steps,
)
prod_steps = settings_validation.get_simsteps(
sim_length=sim_settings.production_length,
timestep=timestep, mc_steps=mc_steps,
)

solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA)
Expand Down
42 changes: 16 additions & 26 deletions openfe/protocols/openmm_utils/settings_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,37 @@ def validate_timestep(hmass: float, timestep: unit.Quantity):
raise ValueError(errmsg)


def get_simsteps(equil_length: unit.Quantity, prod_length: unit.Quantity,
timestep: unit.Quantity, mc_steps: int) -> Tuple[int, int]:
def get_simsteps(sim_length: unit.Quantity,
timestep: unit.Quantity, mc_steps: int) -> int:
"""
Gets and validates the number of equilibration and production steps.
Gets and validates the number of simulation steps.
Parameters
----------
equil_length : unit.Quantity
Simulation equilibration length.
prod_length : unit.Quantity
Simulation production length.
sim_length : unit.Quantity
Simulation length.
timestep : unit.Quantity
Integration timestep.
mc_steps : int
Number of integration timesteps between MCMC moves.
Returns
-------
equil_steps : int
The number of equilibration timesteps.
prod_steps : int
The number of production timesteps.
sim_steps : int
The number of simulation timesteps.
"""

equil_time = round(equil_length.to('attosecond').m)
prod_time = round(prod_length.to('attosecond').m)
sim_time = round(sim_length.to('attosecond').m)
ts = round(timestep.to('attosecond').m)

equil_steps, mod = divmod(equil_time, ts)
sim_steps, mod = divmod(sim_time, ts)
if mod != 0:
raise ValueError("Equilibration time not divisible by timestep")
prod_steps, mod = divmod(prod_time, ts)
if mod != 0:
raise ValueError("Production time not divisible by timestep")
raise ValueError("Simulation time not divisible by timestep")

for var in [("Equilibration", equil_steps, equil_time),
("Production", prod_steps, prod_time)]:
if (var[1] % mc_steps) != 0:
errmsg = (f"{var[0]} time {var[2]/1000000} ps should contain a "
"number of steps divisible by the number of integrator "
f"timesteps between MC moves {mc_steps}")
raise ValueError(errmsg)
if (sim_steps % mc_steps) != 0:
errmsg = (f"Simulation time {sim_time/1000000} ps should contain a "
"number of steps divisible by the number of integrator "
f"timesteps between MC moves {mc_steps}")
raise ValueError(errmsg)

return equil_steps, prod_steps
return sim_steps
49 changes: 18 additions & 31 deletions openfe/tests/protocols/test_openmmutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,30 @@ def test_validate_timestep():
settings_validation.validate_timestep(2.0, 4.0 * unit.femtoseconds)


@pytest.mark.parametrize('e,p,ts,mc,es,ps', [
[1 * unit.nanoseconds, 5 * unit.nanoseconds, 4 * unit.femtoseconds,
250, 250000, 1250000],
[1 * unit.picoseconds, 1 * unit.picoseconds, 2 * unit.femtoseconds,
250, 500, 500],
@pytest.mark.parametrize('s,ts,mc,es', [
[5 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 1250000],
[1 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 250000],
[1 * unit.picoseconds, 2 * unit.femtoseconds, 250, 500],
])
def test_get_simsteps(e, p, ts, mc, es, ps):
equil_steps, prod_steps = settings_validation.get_simsteps(e, p, ts, mc)
def test_get_simsteps(s, ts, mc, es):
sim_steps = settings_validation.get_simsteps(s, ts, mc)

assert equil_steps == es
assert prod_steps == ps
assert sim_steps == es


@pytest.mark.parametrize('nametype, timelengths', [
['Equilibration', [1.003 * unit.picoseconds, 1 * unit.picoseconds]],
['Production', [1 * unit.picoseconds, 1.003 * unit.picoseconds]],
])
def test_get_simsteps_indivisible_simtime(nametype, timelengths):
errmsg = f"{nametype} time not divisible by timestep"
def test_get_simsteps_indivisible_simtime():
errmsg = "Simulation time not divisible by timestep"
timelength = 1.003 * unit.picosecond
with pytest.raises(ValueError, match=errmsg):
settings_validation.get_simsteps(
timelengths[0],
timelengths[1],
2 * unit.femtoseconds,
100)
settings_validation.get_simsteps(timelength, 2 * unit.femtoseconds, 100)


@pytest.mark.parametrize('nametype, timelengths', [
['Equilibration', [1 * unit.picoseconds, 10 * unit.picoseconds]],
['Production', [10 * unit.picoseconds, 1 * unit.picoseconds]],
])
def test_mc_indivisible(nametype, timelengths):
errmsg = f"{nametype} time 1.0 ps should contain"
def test_mc_indivisible():
errmsg = "Simulation time 1.0 ps should contain"
timelength = 1 * unit.picoseconds
with pytest.raises(ValueError, match=errmsg):
settings_validation.get_simsteps(
timelengths[0], timelengths[1],
2 * unit.femtoseconds, 1000)
timelength, 2 * unit.femtoseconds, 1000)


def test_get_alchemical_components(benzene_modifications,
Expand All @@ -90,7 +77,7 @@ def test_get_alchemical_components(benzene_modifications,

def test_duplicate_chemical_components(benzene_modifications):
stateA = openfe.ChemicalSystem({'A': benzene_modifications['toluene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })
stateB = openfe.ChemicalSystem({'A': benzene_modifications['toluene']})

errmsg = "state A components B:"
Expand Down Expand Up @@ -139,7 +126,7 @@ def test_multiple_proteins(T4_protein_component):
def test_get_components_gas(benzene_modifications):

state = openfe.ChemicalSystem({'A': benzene_modifications['benzene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })

s, p, mols = system_validation.get_components(state)

Expand All @@ -152,7 +139,7 @@ def test_components_solvent(benzene_modifications):

state = openfe.ChemicalSystem({'S': openfe.SolventComponent(),
'A': benzene_modifications['benzene'],
'B': benzene_modifications['toluene'],})
'B': benzene_modifications['toluene'], })

s, p, mols = system_validation.get_components(state)

Expand Down

0 comments on commit c24ebb7

Please sign in to comment.