Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renaming and moving some of the protocol settings #689

Merged
merged 45 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
189815e
Initial commit to rename protocol settings
hannahbaumann Jan 17, 2024
7a4266e
more settings changes
hannahbaumann Jan 18, 2024
1ed1ea9
Transfer changes in AlchemicalSamplerSettings to different protocols
hannahbaumann Jan 18, 2024
a16a91e
Change n_steps in protocols
hannahbaumann Jan 18, 2024
190eec9
Propagate changes in OutputSettings to protocols and tests
hannahbaumann Jan 18, 2024
eb238e3
change validator
hannahbaumann Jan 18, 2024
d499356
fixes new settings
hannahbaumann Jan 18, 2024
768b4b3
small changes
hannahbaumann Jan 19, 2024
d7145d8
changes protocol_repeats
hannahbaumann Jan 19, 2024
ff84000
add values to validators
hannahbaumann Jan 19, 2024
2566438
small fix
hannahbaumann Jan 19, 2024
b22c51a
change doc string minimum iterations
hannahbaumann Jan 19, 2024
2dd0faa
more renaming online_analysis to real_time and pep8 fixes
hannahbaumann Jan 19, 2024
b39e6e3
for MD protocol use mcmc_steps=1 since only used for testing the time…
hannahbaumann Jan 19, 2024
e55ce2c
changes softcore_LJ
hannahbaumann Jan 24, 2024
4d08133
Fixes protocol_repeats
hannahbaumann Jan 25, 2024
1d6f9c7
Fix get_integrator
hannahbaumann Jan 25, 2024
b5bca0d
small fixes
hannahbaumann Jan 25, 2024
a3f72d1
Move remove_com to IntegratorSettings
hannahbaumann Jan 25, 2024
ad8170f
more small fixes
hannahbaumann Jan 25, 2024
aef9156
Plain MD protocol move n_repeats from RepeatSettings to ProtocolSetti…
hannahbaumann Jan 25, 2024
c83a3c9
small fixes
hannahbaumann Jan 25, 2024
c20f5b6
Larger settings changes, suggestions
hannahbaumann Jan 26, 2024
573f271
Change unit early_termination_target_error to kcal/mol
hannahbaumann Jan 30, 2024
990667c
Merge branch 'main' into rename_settings
richardjgowers Feb 1, 2024
4a73dff
updates for new Settings
richardjgowers Feb 2, 2024
55981bb
document dev scripts for generating jsons
richardjgowers Feb 5, 2024
60a42bc
fixup usage of 'sampler_settings' in AFE base
richardjgowers Feb 5, 2024
c8acd48
clean up some imports
richardjgowers Feb 5, 2024
3aa4581
use convert_steps_per_iteration in afe base
richardjgowers Feb 5, 2024
238bedf
works on validators and unit conversions
richardjgowers Feb 5, 2024
940424f
fixup on unit conversion
richardjgowers Feb 5, 2024
e0fabfe
remove commented out SamplerSettings block
richardjgowers Feb 5, 2024
6db0c53
added tests for settings_validation conversions
richardjgowers Feb 5, 2024
8bc76b1
got kT conversion upside down
richardjgowers Feb 5, 2024
0b858df
small fixups for new settings
richardjgowers Feb 5, 2024
0fa3995
update result jsons for tests
richardjgowers Feb 5, 2024
03c4a17
shorten AHFE time in dev scripts
richardjgowers Feb 5, 2024
8457b78
Merge branch 'main' into rename_settings
richardjgowers Feb 5, 2024
5f25fd4
shorten AHFE time in dev scripts
richardjgowers Feb 5, 2024
fcf5735
update openfecli test files for new settings
richardjgowers Feb 6, 2024
cc3e22e
revert changes to gen-serialized-results.py
richardjgowers Feb 6, 2024
ae5374e
update away from nose test style
richardjgowers Feb 6, 2024
8c84c7b
fixup LambdaSettings docstrings
richardjgowers Feb 6, 2024
2a948cc
remove doc stubs for removed settings
richardjgowers Feb 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions devtools/data/gen-serialized-results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
"""
Dev script to generate some result jsons that are used for testing

Generates
- AHFEProtocol_json_results.gz
- used in afe_solvation_json fixture
- RHFEProtocol_json_results.gz
- used in rfe_transformation_json fixture
- MDProtocol_json_results.gz
- used in md_json fixture
"""
import gzip
import json
import logging
Expand Down Expand Up @@ -60,7 +71,7 @@ def generate_md_json(smc):
settings.simulation_settings.equilibration_length_nvt = 0.01 * unit.nanosecond
settings.simulation_settings.equilibration_length = 0.01 * unit.nanosecond
settings.simulation_settings.production_length = 0.01 * unit.nanosecond
settings.system_settings.nonbonded_method = "nocutoff"
settings.forcefield_settings.nonbonded_method = "nocutoff"
protocol = PlainMDProtocol(settings=settings)
system = openfe.ChemicalSystem({"ligand": smc})
dag = protocol.create(stateA=system, stateB=system, mapping=None)
Expand All @@ -80,9 +91,11 @@ def generate_ahfe_json(smc):
settings.lambda_settings.lambda_vdw = [0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24,
0.36, 0.48, 0.6, 0.7, 0.77, 0.85,
1.0]
settings.alchemsampler_settings.n_repeats = 3
settings.alchemsampler_settings.n_replicas = 14
settings.alchemsampler_settings.online_analysis_target_error = 0.2 * unit.boltzmann_constant * unit.kelvin
settings.protocol_repeats = 3
settings.solvent_simulation_settings.n_replicas = 14
settings.vacuum_simulation_settings.n_replicas = 14
settings.solvent_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole
settings.vacuum_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole
settings.vacuum_engine_settings.compute_platform = 'CPU'
settings.solvent_engine_settings.compute_platform = 'CUDA'

Expand All @@ -103,7 +116,7 @@ def generate_rfe_json(smcA, smcB):
settings = RelativeHybridTopologyProtocol.default_settings()
settings.simulation_settings.equilibration_length = 10 * unit.picosecond
settings.simulation_settings.production_length = 250 * unit.picosecond
settings.system_settings.nonbonded_method = "nocutoff"
settings.forcefield_settings.nonbonded_method = "nocutoff"
protocol = RelativeHybridTopologyProtocol(settings=settings)

a_smcB = align_mol_shape(smcB, ref_mol=smcA)
Expand Down
112 changes: 67 additions & 45 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@
)
from openfe.protocols.openmm_afe.equil_afe_settings import (
SolvationSettings,
AlchemicalSamplerSettings, OpenMMEngineSettings,
IntegratorSettings, SimulationSettings, LambdaSettings,
MultiStateSimulationSettings, OpenMMEngineSettings,
IntegratorSettings, LambdaSettings, OutputSettings,
ThermoSettings,
)
from openfe.protocols.openmm_rfe._rfe_utils import compute
from ..openmm_utils import (
Expand Down Expand Up @@ -230,14 +231,13 @@
Get a dictionary with the following entries:
* forcefield_settings : OpenMMSystemGeneratorFFSettings
* thermo_settings : ThermoSettings
* system_settings : SystemSettings
* solvation_settings : SolvationSettings
* alchemical_settings : AlchemicalSettings
* lambda_settings : LambdaSettings
* sampler_settings : AlchemicalSamplerSettings
* engine_settings : OpenMMEngineSettings
* integrator_settings : IntegratorSettings
* simulation_settings : SimulationSettings
* simulation_settings : MultiStateSimulationSettings
* output_settings: OutputSettings

Settings may change depending on what type of simulation you are
running. Cherry pick them and return them to be available later on.
Expand Down Expand Up @@ -270,15 +270,14 @@
system_generator : openmmforcefields.generator.SystemGenerator
System Generator to parameterise this unit.
"""
ffcache = settings['simulation_settings'].forcefield_cache
ffcache = settings['output_settings'].forcefield_cache
if ffcache is not None:
ffcache = self.shared_basepath / ffcache

system_generator = system_creation.get_system_generator(
forcefield_settings=settings['forcefield_settings'],
thermo_settings=settings['thermo_settings'],
integrator_settings=settings['integrator_settings'],
system_settings=settings['system_settings'],
thermo_settings=settings['thermo_settings'],
cache=ffcache,
has_solvent=solvent_comp is not None,
)
Expand Down Expand Up @@ -537,7 +536,7 @@
self,
topology: app.Topology,
positions: openmm.unit.Quantity,
simulation_settings: SimulationSettings,
output_settings: OutputSettings,
) -> multistate.MultiStateReporter:
"""
Get a MultistateReporter for the simulation you are running.
Expand All @@ -546,8 +545,8 @@
----------
topology : app.Topology
A Topology of the system being created.
simulation_settings : SimulationSettings
Settings for the simulation.
output_settings: OutputSettings
Output settings for the simulations

Returns
-------
Expand All @@ -557,16 +556,16 @@
mdt_top = mdt.Topology.from_openmm(topology)

selection_indices = mdt_top.select(
simulation_settings.output_indices
output_settings.output_indices
)

nc = self.shared_basepath / simulation_settings.output_filename
chk = simulation_settings.checkpoint_storage
nc = self.shared_basepath / output_settings.output_filename
chk = output_settings.checkpoint_storage_filename

reporter = multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=simulation_settings.checkpoint_interval.m,
checkpoint_interval=output_settings.checkpoint_interval.m,
checkpoint_storage=chk,
)

Expand All @@ -577,7 +576,7 @@
mdt_top.subset(selection_indices),
)
traj.save_pdb(
self.shared_basepath / simulation_settings.output_structure
self.shared_basepath / output_settings.output_structure
)

return reporter
Expand Down Expand Up @@ -614,38 +613,45 @@

return energy_context_cache, sampler_context_cache

@staticmethod
def _get_integrator(
self,
integrator_settings: IntegratorSettings
integrator_settings: IntegratorSettings,
simulation_settings: MultiStateSimulationSettings
) -> openmmtools.mcmc.LangevinDynamicsMove:
"""
Return a LangevinDynamicsMove integrator

Parameters
----------
integrator_settings : IntegratorSettings
simulation_settings : MultiStateSimulationSettings

Returns
-------
integrator : openmmtools.mcmc.LangevinDynamicsMove
A configured integrator object.
"""
steps_per_iteration = settings_validation.convert_steps_per_iteration(
simulation_settings, integrator_settings
)

integrator = openmmtools.mcmc.LangevinDynamicsMove(
timestep=to_openmm(integrator_settings.timestep),
collision_rate=to_openmm(integrator_settings.collision_rate),
n_steps=integrator_settings.n_steps.m,
collision_rate=to_openmm(integrator_settings.langevin_collision_rate),
n_steps=steps_per_iteration,
reassign_velocities=integrator_settings.reassign_velocities,
n_restart_attempts=integrator_settings.n_restart_attempts,
constraint_tolerance=integrator_settings.constraint_tolerance,
)

return integrator

@staticmethod
def _get_sampler(
self,
integrator: openmmtools.mcmc.LangevinDynamicsMove,
reporter: openmmtools.multistate.MultiStateReporter,
sampler_settings: AlchemicalSamplerSettings,
simulation_settings: MultiStateSimulationSettings,
thermo_settings: ThermoSettings,
cmp_states: list[ThermodynamicState],
sampler_states: list[SamplerState],
energy_context_cache: openmmtools.cache.ContextCache,
Expand All @@ -660,8 +666,10 @@
The simulation integrator.
reporter : openmmtools.multistate.MultiStateReporter
The reporter to hook up to the sampler.
sampler_settings : AlchemicalSamplerSettings
simulation_settings : MultiStateSimulationSettings
Settings for the alchemical sampler.
thermo_settings : ThermoSettings
Thermodynamic settings
cmp_states : list[ThermodynamicState]
A list of thermodynamic states to sample.
sampler_states : list[SamplerState]
Expand All @@ -676,30 +684,37 @@
sampler : multistate.MultistateSampler
A sampler configured for the chosen sampling method.
"""
rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations(
simulation_settings=simulation_settings,
)
et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT(
thermo_settings.temperature,
simulation_settings.early_termination_target_error,
)

# Select the right sampler
# Note: doesn't need else, settings already validates choices
if sampler_settings.sampler_method.lower() == "repex":
if simulation_settings.sampler_method.lower() == "repex":
sampler = multistate.ReplicaExchangeSampler(
mcmc_moves=integrator,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations
online_analysis_interval=rta_its,
online_analysis_target_error=et_target_err,
online_analysis_minimum_iterations=rta_min_its
)
elif sampler_settings.sampler_method.lower() == "sams":
elif simulation_settings.sampler_method.lower() == "sams":
sampler = multistate.SAMSSampler(
mcmc_moves=integrator,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations,
flatness_criteria=sampler_settings.flatness_criteria,
gamma0=sampler_settings.gamma0,
online_analysis_interval=rta_its,
online_analysis_minimum_iterations=rta_min_its,
flatness_criteria=simulation_settings.sams_flatness_criteria,
gamma0=simulation_settings.sams_gamma0,
)
elif sampler_settings.sampler_method.lower() == 'independent':
elif simulation_settings.sampler_method.lower() == 'independent':
sampler = multistate.MultiStateSampler(
mcmc_moves=integrator,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations
online_analysis_interval=rta_its,
online_analysis_target_error=et_target_err,
online_analysis_minimum_iterations=rta_min_its,
)

sampler.create(
Expand Down Expand Up @@ -741,7 +756,10 @@
if not a dry run.
"""
# Get the relevant simulation steps
mc_steps = settings['integrator_settings'].n_steps.m
mc_steps = settings_validation.convert_steps_per_iteration(
simulation_settings=settings['simulation_settings'],
integrator_settings=settings['integrator_settings'],
)

equil_steps = settings_validation.get_simsteps(
sim_length=settings['simulation_settings'].equilibration_length,
Expand Down Expand Up @@ -780,7 +798,7 @@

analyzer = multistate_analysis.MultistateEquilFEAnalysis(
reporter,
sampling_method=settings['sampler_settings'].sampler_method.lower(),
sampling_method=settings['simulation_settings'].sampler_method.lower(),
result_units=unit.kilocalorie_per_mole
)
analyzer.plot(filepath=self.shared_basepath, filename_prefix="")
Expand All @@ -793,8 +811,8 @@
reporter.close()

# clean up the reporter file
fns = [self.shared_basepath / settings['simulation_settings'].output_filename,
self.shared_basepath / settings['simulation_settings'].checkpoint_storage]
fns = [self.shared_basepath / settings['output_settings'].output_filename,
self.shared_basepath / settings['output_settings'].checkpoint_storage_filename]
for fn in fns:
os.remove(fn)

Expand Down Expand Up @@ -878,7 +896,7 @@
# 11. Create the multistate reporter & create PDB
reporter = self._get_reporter(
omm_topology, positions,
settings['simulation_settings'],
settings['output_settings'],
)

# Wrap in try/finally to avoid memory leak issues
Expand All @@ -889,11 +907,15 @@
)

# 13. Get integrator
integrator = self._get_integrator(settings['integrator_settings'])
integrator = self._get_integrator(
settings['integrator_settings'],
settings['simulation_settings'],
)

# 14. Get sampler
sampler = self._get_sampler(
integrator, reporter, settings['sampler_settings'],
integrator, reporter, settings['simulation_settings'],
settings['thermo_settings'],
cmp_states, sampler_states,
energy_ctx_cache, sampler_ctx_cache
)
Expand Down Expand Up @@ -925,8 +947,8 @@
del integrator, sampler

if not dry:
nc = self.shared_basepath / settings['simulation_settings'].output_filename
chk = settings['simulation_settings'].checkpoint_storage
nc = self.shared_basepath / settings['output_settings'].output_filename
chk = settings['output_settings'].checkpoint_storage_filename

Check warning on line 951 in openfe/protocols/openmm_afe/base.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_afe/base.py#L950-L951

Added lines #L950 - L951 were not covered by tests
return {
'nc': nc,
'last_checkpoint': chk,
Expand Down
Loading
Loading