diff --git a/.gitignore b/.gitignore index 68b6c9d50..b8552aeeb 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,8 @@ cython_debug/ # vim *.swp +# vscode +.vscode/ + # Example notebooks -docs/ExampleNotebooks/ +docs/ExampleNotebooks/ \ No newline at end of file diff --git a/docs/reference/api/index.rst b/docs/reference/api/index.rst index 5d0df4f7e..2ed7668ff 100644 --- a/docs/reference/api/index.rst +++ b/docs/reference/api/index.rst @@ -12,3 +12,4 @@ OpenFE API Reference alchemical_network_planning defining_and_executing_simulations openmm_rfe + openmm_solvation_afe diff --git a/docs/reference/api/openmm_solvation_afe.rst b/docs/reference/api/openmm_solvation_afe.rst new file mode 100644 index 000000000..f3bc78eb4 --- /dev/null +++ b/docs/reference/api/openmm_solvation_afe.rst @@ -0,0 +1,146 @@ +OpenMM Absolute Solvation Free Energy Protocol +============================================== + +This section provides details about the OpenMM Absolute Solvation Free Energy Protocol +implemented in OpenFE. + +Protocol API specification +-------------------------- + +.. module:: openfe.protocols.openmm_afe.equil_solvation_afe_method + +.. autosummary:: + :nosignatures: + :toctree: generated/ + + AbsoluteSolvationProtocol + AbsoluteSolvationProtocolResult + +Protocol Settings +----------------- + + +Below are the settings which can be tweaked in the protocol. The default settings (accessed using :meth:`AbsoluteSolvationProtocol.default_settings`) will automatically populate settings which we have found to be useful for running solvation free energy calculations. There will however be some cases (such as when calculating difficult to converge systems) where you will need to tweak some of the following settings. + +.. autopydantic_model:: AbsoluteSolvationSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :exclude-members: get_defaults + :member-order: bysource + :noindex: + +.. module:: openfe.protocols.openmm_afe.equil_afe_settings + +.. autopydantic_model:: OpenMMSystemGeneratorFFSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: ThermoSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: AlchemicalSamplerSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: AlchemicalSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: OpenMMEngineSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: IntegratorSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: SimulationSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: SolvationSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: + +.. autopydantic_model:: SystemSettings + :model-show-json: False + :model-show-field-summary: False + :model-show-config-member: False + :model-show-config-summary: False + :model-show-validator-members: False + :model-show-validator-summary: False + :field-list-validators: False + :inherited-members: SettingsBaseModel + :member-order: bysource + :noindex: diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py new file mode 100644 index 000000000..40c10e8aa --- /dev/null +++ b/openfe/protocols/openmm_afe/__init__.py @@ -0,0 +1,22 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Run absolute free energy calculations using OpenMM and OpenMMTools. + +""" + +from .equil_solvation_afe_method import ( + AbsoluteSolvationProtocol, + AbsoluteSolvationSettings, + AbsoluteSolvationProtocolResult, + AbsoluteSolvationVacuumUnit, + AbsoluteSolvationSolventUnit, +) + +__all__ = [ + "AbsoluteSolvationProtocol", + "AbsoluteSolvationSettings", + "AbsoluteSolvationProtocolResult", + "AbsoluteVacuumUnit", + "AbsoluteSolventUnit", +] diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py new file mode 100644 index 000000000..e66508771 --- /dev/null +++ b/openfe/protocols/openmm_afe/base.py @@ -0,0 +1,925 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium AFE Protocol base classes +=============================================== + +Base classes for the equilibrium OpenMM absolute free energy ProtocolUnits. + +Thist mostly implements BaseAbsoluteUnit whose methods can be +overriden to define different types of alchemical transformations. + +TODO +---- +* Add in all the AlchemicalFactory and AlchemicalRegion kwargs + as settings. +* Allow for a more flexible setting of Lambda regions. +""" +from __future__ import annotations + +import abc +import os +import logging + +import gufe +from gufe.components import Component +import numpy as np +import numpy.typing as npt +import openmm +from openff.units import unit +from openff.units.openmm import from_openmm, to_openmm, ensure_quantity +from openmmtools import multistate +from openmmtools.states import (SamplerState, + ThermodynamicState, + create_thermodynamic_state_protocol,) +from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory, + AlchemicalState,) +from typing import Dict, List, Optional +from openmm import app +from openmm import unit as omm_unit +from openmmforcefields.generators import SystemGenerator +import pathlib +from typing import Any +import openmmtools +import mdtraj as mdt + +from gufe import ( + settings, ChemicalSystem, SmallMoleculeComponent, + ProteinComponent, SolventComponent +) +from openfe.protocols.openmm_utils.omm_settings import ( + SettingsBaseModel, +) +from openfe.protocols.openmm_afe.equil_afe_settings import ( + SolvationSettings, + AlchemicalSamplerSettings, OpenMMEngineSettings, + IntegratorSettings, SimulationSettings, +) +from openfe.protocols.openmm_rfe._rfe_utils import compute +from ..openmm_utils import ( + settings_validation, system_creation, + multistate_analysis +) + +logger = logging.getLogger(__name__) + + +class BaseAbsoluteUnit(gufe.ProtocolUnit): + """ + Base class for ligand absolute free energy transformations. + """ + def __init__(self, *, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + settings: settings.Settings, + alchemical_components: Dict[str, List[str]], + generation: int = 0, + repeat_id: int = 0, + name: Optional[str] = None,): + """ + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem containing the components defining the state at + lambda 0. + stateB : ChemicalSystem + ChemicalSystem containing the components defining the state at + lambda 1. + settings : gufe.settings.Setings + Settings for the Absolute Tranformation Protocol. This can be + constructed by calling the + :class:`AbsoluteTransformProtocol.get_default_settings` method + to get a default set of settings. + name : str, optional + Human-readable identifier for this Unit + repeat_id : int, optional + Identifier for which repeat (aka replica/clone) this Unit is, + default 0 + generation : int, optional + Generation counter which keeps track of how many times this repeat + has been extended, default 0. + """ + super().__init__( + name=name, + stateA=stateA, + stateB=stateB, + settings=settings, + alchemical_components=alchemical_components, + repeat_id=repeat_id, + generation=generation, + ) + + @staticmethod + def _get_alchemical_indices(omm_top: openmm.Topology, + comp_resids: Dict[str, npt.NDArray], + alchem_comps: Dict[str, List[Component]] + ) -> List[int]: + """ + Get a list of atom indices for all the alchemical species + + Parameters + ---------- + omm_top : openmm.Topology + Topology of OpenMM System. + comp_resids : Dict[str, npt.NDArray] + A dictionary of residues for each component in the System. + alchem_comps : Dict[str, List[Component]] + A dictionary of alchemical components for each end state. + + Return + ------ + atom_ids : List[int] + A list of atom indices for the alchemical species + """ + + # concatenate a list of residue indexes for all alchemical components + residxs = np.concatenate( + [comp_resids[key] for key in alchem_comps['stateA']] + ) + + # get the alchemicical atom ids + atom_ids = [] + + for r in omm_top.residues(): + if r.index in residxs: + atom_ids.extend([at.index for at in r.atoms()]) + + return atom_ids + + @staticmethod + def _pre_minimize(system: openmm.System, + positions: omm_unit.Quantity) -> npt.NDArray: + """ + Short CPU minization of System to avoid GPU NaNs + + Parameters + ---------- + system : openmm.System + An OpenMM System to minimize. + positionns : openmm.unit.Quantity + Initial positions for the system. + + Returns + ------- + minimized_positions : npt.NDArray + Minimized positions + """ + integrator = openmm.VerletIntegrator(0.001) + context = openmm.Context( + system, integrator, + openmm.Platform.getPlatformByName('CPU'), + ) + context.setPositions(positions) + # Do a quick 100 steps minimization, usually avoids NaNs + openmm.LocalEnergyMinimizer.minimize( + context, maxIterations=100 + ) + state = context.getState(getPositions=True) + minimized_positions = state.getPositions(asNumpy=True) + return minimized_positions + + def _prepare( + self, verbose: bool, + scratch_basepath: Optional[pathlib.Path], + shared_basepath: Optional[pathlib.Path], + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + basepath : Optional[pathlib.Path] + Optional base path to write files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("setting up alchemical system") + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path('.') + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @abc.abstractmethod + def _get_components(self): + """ + Get the relevant components to create the alchemical system with. + + Note + ---- + Must be implemented in the child class. + """ + ... + + @abc.abstractmethod + def _handle_settings(self): + """ + Get a dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * system_settings : SystemSettings + * solvation_settings : SolvationSettings + * alchemical_settings : AlchemicalSettings + * sampler_settings : AlchemicalSamplerSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * simulation_settings : SimulationSettings + + Settings may change depending on what type of simulation you are + running. Cherry pick them and return them to be available later on. + + This method should also add various validation checks as necessary. + + Note + ---- + Must be implemented in the child class. + """ + ... + + def _get_system_generator( + self, settings: dict[str, SettingsBaseModel], + solvent_comp: Optional[SolventComponent] + ) -> SystemGenerator: + """ + Get a system generator through the system creation + utilities + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + A dictionary of settings object for the unit. + solvent_comp : Optional[SolventComponent] + The solvent component of this system, if there is one. + + Returns + ------- + system_generator : openmmforcefields.generator.SystemGenerator + System Generator to parameterise this unit. + """ + ffcache = settings['simulation_settings'].forcefield_cache + if ffcache is not None: + ffcache = self.shared_basepath / ffcache + + system_generator = system_creation.get_system_generator( + forcefield_settings=settings['forcefield_settings'], + thermo_settings=settings['thermo_settings'], + system_settings=settings['system_settings'], + cache=ffcache, + has_solvent=solvent_comp is not None, + ) + return system_generator + + def _get_modeller( + self, + protein_component: Optional[ProteinComponent], + solvent_component: Optional[SolventComponent], + smc_components: list[SmallMoleculeComponent], + system_generator: SystemGenerator, + solvation_settings: SolvationSettings + ) -> tuple[app.Modeller, dict[Component, npt.NDArray]]: + """ + Get an OpenMM Modeller object and a list of residue indices + for each component in the system. + + Parameters + ---------- + protein_component : Optional[ProteinComponent] + Protein Component, if it exists. + solvent_component : Optional[ProteinCompoinent] + Solvent Component, if it exists. + smc_components : list[SmallMoleculeComponents] + List of SmallMoleculeComponents to add. + system_generator : openmmforcefields.generator.SystemGenerator + System Generator to parameterise this unit. + solvation_settings : SolvationSettings + Settings detailing how to solvate the system. + + Returns + ------- + system_modeller : app.Modeller + OpenMM Modeller object generated from ProteinComponent and + OpenFF Molecules. + comp_resids : dict[Component, npt.NDArray] + Dictionary of residue indices for each component in system. + """ + if self.verbose: + self.logger.info("Parameterizing molecules") + + # force the creation of parameters for the small molecules + # this is necessary because we need to have the FF generated ahead + # of solvating the system. + # Note by default this is cached to ctx.shared/db.json which should + # reduce some of the costs. + for comp in smc_components: + offmol = comp.to_openff() + system_generator.create_system( + offmol.to_topology().to_openmm(), molecules=[offmol] + ) + + # get OpenMM modeller + dictionary of resids for each component + system_modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=smc_components, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) + + return system_modeller, comp_resids + + def _get_omm_objects( + self, + system_modeller: app.Modeller, + system_generator: SystemGenerator, + smc_components: list[SmallMoleculeComponent], + ) -> tuple[app.Topology, openmm.unit.Quantity, openmm.System]: + """ + Get the OpenMM Topology, Positions and System of the + parameterised system. + + Parameters + ---------- + system_modeller : app.Modeller + OpenMM Modeller object representing the system to be + parametrized. + system_generator : SystemGenerator + SystemGenerator object to create a System with. + smc_components : list[SmallMoleculeComponent] + A list of SmallMoleculeComponents to add to the system. + + Returns + ------- + topology : app.Topology + Topology object describing the parameterized system + system : openmm.System + An OpenMM System of the alchemical system. + positionns : openmm.unit.Quantity + Positions of the system. + """ + topology = system_modeller.getTopology() + # roundtrip positions to remove vec3 issues + positions = to_openmm(from_openmm(system_modeller.getPositions())) + system = system_generator.create_system( + system_modeller.topology, + molecules=[s.to_openff() for s in smc_components] + ) + return topology, system, positions + + def _get_lambda_schedule( + self, settings: dict[str, SettingsBaseModel] + ) -> dict[str, npt.NDArray]: + """ + Create the lambda schedule + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + Settings for the unit. + + Returns + ------- + lambdas : dict[str, npt.NDArray] + + TODO + ---- + * Augment this by using something akin to the RFE protocol's + LambdaProtocol + """ + lambdas = dict() + n_elec = settings['alchemical_settings'].lambda_elec_windows + n_vdw = settings['alchemical_settings'].lambda_vdw_windows + 1 + lambdas['lambda_electrostatics'] = np.concatenate( + [np.linspace(1, 0, n_elec), np.linspace(0, 0, n_vdw)[1:]] + ) + lambdas['lambda_sterics'] = np.concatenate( + [np.linspace(1, 1, n_elec), np.linspace(1, 0, n_vdw)[1:]] + ) + + n_replicas = settings['sampler_settings'].n_replicas + + if n_replicas != (len(lambdas['lambda_sterics'])): + errmsg = (f"Number of replicas {n_replicas} " + "does not equal the number of lambda windows ") + raise ValueError(errmsg) + + return lambdas + + def _add_restraints(self, system, topology, settings): + """ + Placeholder method to add restraints if necessary + """ + return + + def _get_alchemical_system( + self, + topology: app.Topology, + system: openmm.System, + comp_resids: dict[Component, npt.NDArray], + alchem_comps: dict[str, list[Component]] + ) -> tuple[AbsoluteAlchemicalFactory, openmm.System, list[int]]: + """ + Get an alchemically modified system and its associated factory + + Parameters + ---------- + topology : openmm.Topology + Topology of OpenMM System. + system : openmm.System + System to alchemically modify. + comp_resids : dict[str, npt.NDArray] + A dictionary of residues for each component in the System. + alchem_comps : dict[str, list[Component]] + A dictionary of alchemical components for each end state. + + + Returns + ------- + alchemical_factory : AbsoluteAlchemicalFactory + Factory for creating an alchemically modified system. + alchemical_system : openmm.System + Alchemically modified system + alchemical_indices : list[int] + A list of atom indices for the alchemically modified + species in the system. + + TODO + ---- + * Add support for all alchemical factory options + """ + alchemical_indices = self._get_alchemical_indices( + topology, comp_resids, alchem_comps + ) + + alchemical_region = AlchemicalRegion( + alchemical_atoms=alchemical_indices, + ) + + alchemical_factory = AbsoluteAlchemicalFactory() + alchemical_system = alchemical_factory.create_alchemical_system( + system, alchemical_region + ) + + return alchemical_factory, alchemical_system, alchemical_indices + + def _get_states( + self, + alchemical_system: openmm.System, + positions: openmm.unit.Quantity, + settings: dict[str, SettingsBaseModel], + lambdas: dict[str, npt.NDArray], + solvent_comp: Optional[SolventComponent], + ) -> tuple[list[SamplerState], list[ThermodynamicState]]: + """ + Get a list of sampler and thermodynmic states from an + input alchemical system. + + Parameters + ---------- + alchemical_system : openmm.System + Alchemical system to get states for. + positions : openmm.unit.Quantity + Positions of the alchemical system. + settings : dict[str, SettingsBaseModel] + A dictionary of settings for the protocol unit. + lambdas : dict[str, npt.NDArray] + A dictionary of lambda scales. + solvent_comp : Optional[SolventComponent] + The solvent component of the system, if there is one. + + Returns + ------- + sampler_states : list[SamplerState] + A list of SamplerStates for each replica in the system. + cmp_states : list[ThermodynamicState] + A list of ThermodynamicState for each replica in the system. + """ + alchemical_state = AlchemicalState.from_system(alchemical_system) + # Set up the system constants + temperature = settings['thermo_settings'].temperature + pressure = settings['thermo_settings'].pressure + constants = dict() + constants['temperature'] = ensure_quantity(temperature, 'openmm') + if solvent_comp is not None: + constants['pressure'] = ensure_quantity(pressure, 'openmm') + + cmp_states = create_thermodynamic_state_protocol( + alchemical_system, protocol=lambdas, + constants=constants, composable_states=[alchemical_state], + ) + + sampler_state = SamplerState(positions=positions) + if alchemical_system.usesPeriodicBoundaryConditions(): + box = alchemical_system.getDefaultPeriodicBoxVectors() + sampler_state.box_vectors = box + + sampler_states = [sampler_state for _ in cmp_states] + + return sampler_states, cmp_states + + def _get_reporter( + self, + topology: app.Topology, + positions: openmm.unit.Quantity, + simulation_settings: SimulationSettings, + ) -> multistate.MultiStateReporter: + """ + Get a MultistateReporter for the simulation you are running. + + Parameters + ---------- + topology : app.Topology + A Topology of the system being created. + simulation_settings : SimulationSettings + Settings for the simulation. + + Returns + ------- + reporter : multistate.MultiStateReporter + The reporter for the simulation. + """ + mdt_top = mdt.Topology.from_openmm(topology) + + selection_indices = mdt_top.select( + simulation_settings.output_indices + ) + + nc = self.shared_basepath / simulation_settings.output_filename + chk = self.shared_basepath / simulation_settings.checkpoint_storage + + reporter = multistate.MultiStateReporter( + storage=nc, + analysis_particle_indices=selection_indices, + checkpoint_interval=simulation_settings.checkpoint_interval.m, + checkpoint_storage=chk, + ) + + # Write out the structure's PDB whilst we're here + if len(selection_indices) > 0: + traj = mdt.Trajectory( + positions[selection_indices, :], + mdt_top.subset(selection_indices), + ) + traj.save_pdb( + self.shared_basepath / simulation_settings.output_structure + ) + + return reporter + + def _get_ctx_caches( + self, + engine_settings: OpenMMEngineSettings + ) -> tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]: + """ + Set the context caches based on the chosen platform + + Parameters + ---------- + engine_settings : OpenMMEngineSettings, + + Returns + ------- + energy_context_cache : openmmtools.cache.ContextCache + The energy state context cache. + sampler_context_cache : openmmtools.cache.ContextCache + The sampler state context cache. + """ + platform = compute.get_openmm_platform( + engine_settings.compute_platform, + ) + + energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) + + sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) + + return energy_context_cache, sampler_context_cache + + def _get_integrator( + self, + integrator_settings: IntegratorSettings + ) -> openmmtools.mcmc.LangevinDynamicsMove: + """ + Return a LangevinDynamicsMove integrator + + Parameters + ---------- + integrator_settings : IntegratorSettings + + Returns + ------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + A configured integrator object. + """ + integrator = openmmtools.mcmc.LangevinDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.collision_rate), + n_steps=integrator_settings.n_steps.m, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + ) + + return integrator + + def _get_sampler( + self, + integrator: openmmtools.mcmc.LangevinDynamicsMove, + reporter: openmmtools.multistate.MultiStateReporter, + sampler_settings: AlchemicalSamplerSettings, + cmp_states: list[ThermodynamicState], + sampler_states: list[SamplerState], + energy_context_cache: openmmtools.cache.ContextCache, + sampler_context_cache: openmmtools.cache.ContextCache + ) -> multistate.MultiStateSampler: + """ + Get a sampler based on the equilibrium sampling method requested. + + Parameters + ---------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + The simulation integrator. + reporter : openmmtools.multistate.MultiStateReporter + The reporter to hook up to the sampler. + sampler_settings : AlchemicalSamplerSettings + Settings for the alchemical sampler. + cmp_states : list[ThermodynamicState] + A list of thermodynamic states to sample. + sampler_states : list[SamplerState] + A list of sampler states. + energy_context_cache : openmmtools.cache.ContextCache + Context cache for the energy states. + sampler_context_cache : openmmtool.cache.ContextCache + Context cache for the sampler states. + + Returns + ------- + sampler : multistate.MultistateSampler + A sampler configured for the chosen sampling method. + """ + + # Select the right sampler + # Note: doesn't need else, settings already validates choices + if sampler_settings.sampler_method.lower() == "repex": + sampler = multistate.ReplicaExchangeSampler( + mcmc_moves=integrator, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + ) + elif sampler_settings.sampler_method.lower() == "sams": + sampler = multistate.SAMSSampler( + mcmc_moves=integrator, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations, + flatness_criteria=sampler_settings.flatness_criteria, + gamma0=sampler_settings.gamma0, + ) + elif sampler_settings.sampler_method.lower() == 'independent': + sampler = multistate.MultiStateSampler( + mcmc_moves=integrator, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + ) + + sampler.create( + thermodynamic_states=cmp_states, + sampler_states=sampler_states, + storage=reporter + ) + + sampler.energy_context_cache = energy_context_cache + sampler.sampler_context_cache = sampler_context_cache + + return sampler + + def _run_simulation( + self, + sampler: multistate.MultiStateSampler, + reporter: multistate.MultiStateReporter, + settings: dict[str, SettingsBaseModel], + dry: bool + ): + """ + Run the simulation. + + Parameters + ---------- + sampler : multistate.MultiStateSampler + The sampler associated with the simulation to run. + reporter : multistate.MultiStateReporter + The reporter associated with the sampler. + settings : dict[str, SettingsBaseModel] + The dictionary of settings for the protocol. + dry : bool + Whether or not to dry run the simulation + + Returns + ------- + unit_results_dict : Optional[dict] + A dictionary containing all the free energy results, + if not a dry run. + """ + # Get the relevant simulation steps + mc_steps = settings['integrator_settings'].n_steps.m + + equil_steps, prod_steps = settings_validation.get_simsteps( + equil_length=settings['simulation_settings'].equilibration_length, + prod_length=settings['simulation_settings'].production_length, + timestep=settings['integrator_settings'].timestep, + mc_steps=mc_steps, + ) + + if not dry: # pragma: no-cover + # minimize + if self.verbose: + self.logger.info("minimizing systems") + + sampler.minimize( + max_iterations=settings['simulation_settings'].minimization_steps + ) + + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") + + sampler.equilibrate(int(equil_steps / mc_steps)) # type: ignore + + # production + if self.verbose: + self.logger.info("running production phase") + + sampler.extend(int(prod_steps / mc_steps)) # type: ignore + + if self.verbose: + self.logger.info("production phase complete") + + if self.verbose: + self.logger.info("post-simulation result analysis") + + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter, + sampling_method=settings['sampler_settings'].sampler_method.lower(), + result_units=unit.kilocalorie_per_mole + ) + analyzer.plot(filepath=self.shared_basepath, filename_prefix="") + analyzer.close() + + return analyzer.unit_results_dict + + else: + # close reporter when you're done, prevent file handle clashes + reporter.close() + + # clean up the reporter file + fns = [self.shared_basepath / settings['simulation_settings'].output_filename, + self.shared_basepath / settings['simulation_settings'].checkpoint_storage] + for fn in fns: + os.remove(fn) + + return None + + def run(self, dry=False, verbose=True, + scratch_basepath=None, shared_basepath=None) -> Dict[str, Any]: + """Run the absolute free energy calculation. + + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary alchemical + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + basepath : Pathlike, optional + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + + Attributes + ---------- + solvent : Optional[SolventComponent] + SolventComponent to be applied to the system + protein : Optional[ProteinComponent] + ProteinComponent for the system + openff_mols : List[openff.Molecule] + List of OpenFF Molecule objects for each SmallMoleculeComponent in + the stateA ChemicalSystem + """ + # 0. Generaly preparation tasks + self._prepare(verbose, scratch_basepath, shared_basepath) + + # 1. Get components + alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components() + + # 2. Get settings + settings = self._handle_settings() + + # 3. Get system generator + system_generator = self._get_system_generator(settings, solv_comp) + + # 4. Get modeller + system_modeller, comp_resids = self._get_modeller( + prot_comp, solv_comp, smc_comps, system_generator, + settings['solvation_settings'] + ) + + # 5. Get OpenMM topology, positions and system + omm_topology, omm_system, positions = self._get_omm_objects( + system_modeller, system_generator, smc_comps + ) + + # 6. Pre-minimize System (Test + Avoid NaNs) + positions = self._pre_minimize(omm_system, positions) + + # 7. Get lambdas + lambdas = self._get_lambda_schedule(settings) + + # 8. Add restraints + self._add_restraints(omm_system, omm_topology, settings) + + # 9. Get alchemical system + alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( + omm_topology, omm_system, comp_resids, alchem_comps + ) + + # 10. Get compound and sampler states + sampler_states, cmp_states = self._get_states( + alchem_system, positions, settings, + lambdas, solv_comp + ) + + # 11. Create the multistate reporter & create PDB + reporter = self._get_reporter( + omm_topology, positions, + settings['simulation_settings'], + ) + + # Wrap in try/finally to avoid memory leak issues + try: + # 12. Get context caches + energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( + settings['engine_settings'] + ) + + # 13. Get integrator + integrator = self._get_integrator(settings['integrator_settings']) + + # 14. Get sampler + sampler = self._get_sampler( + integrator, reporter, settings['sampler_settings'], + cmp_states, sampler_states, + energy_ctx_cache, sampler_ctx_cache + ) + + # 15. Run simulation + unit_result_dict = self._run_simulation( + sampler, reporter, settings, dry + ) + + finally: + # close reporter when you're done to prevent file handle clashes + reporter.close() + + # clear GPU context + # Note: use cache.empty() when openmmtools #690 is resolved + for context in list(energy_ctx_cache._lru._data.keys()): + del energy_ctx_cache._lru._data[context] + for context in list(sampler_ctx_cache._lru._data.keys()): + del sampler_ctx_cache._lru._data[context] + # cautiously clear out the global context cache too + for context in list( + openmmtools.cache.global_context_cache._lru._data.keys()): + del openmmtools.cache.global_context_cache._lru._data[context] + + del sampler_ctx_cache, energy_ctx_cache + + # Keep these around in a dry run so we can inspect things + if not dry: + del integrator, sampler + + if not dry: + nc = self.shared_basepath / settings['simulation_settings'].output_filename + chk = self.shared_basepath / settings['simulation_settings'].checkpoint_storage + return { + 'nc': nc, + 'last_checkpoint': chk, + **unit_result_dict, + } + else: + return {'debug': {'sampler': sampler}} diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py new file mode 100644 index 000000000..6a473f73d --- /dev/null +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -0,0 +1,131 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +"""Settings class for equilibrium AFE Protocols using OpenMM + OpenMMTools + +This module implements the necessary settings necessary to run absolute free +energies using OpenMM. + +See Also +-------- +openfe.protocols.openmm_afe.AbsoluteSolvationProtocol + +TODO +---- +* Add support for restraints + +""" +from gufe.settings import ( + Settings, + SettingsBaseModel, + OpenMMSystemGeneratorFFSettings, + ThermoSettings, +) +from openfe.protocols.openmm_utils.omm_settings import ( + SystemSettings, + SolvationSettings, + AlchemicalSamplerSettings, + OpenMMEngineSettings, + IntegratorSettings, + SimulationSettings +) + + +try: + from pydantic.v1 import validator +except ImportError: + from pydantic import validator # type: ignore[assignment] + + +class AlchemicalSettings(SettingsBaseModel): + """Settings for the alchemical protocol + + These settings describe the lambda schedule and the creation of the + hybrid system. + """ + + lambda_elec_windows = 12 + """Number of lambda electrostatic alchemical steps, default 12""" + lambda_vdw_windows = 12 + """Number of lambda vdw alchemical steps, default 12""" + + @validator('lambda_elec_windows', 'lambda_vdw_windows') + def must_be_positive(cls, v): + if v <= 0: + errmsg = ("Number of lambda steps must be positive ") + raise ValueError(errmsg) + return v + + +class AbsoluteSolvationSettings(Settings): + """ + Configuration object for ``AbsoluteSolvationProtocol``. + + See Also + -------- + openfe.protocols.openmm_afe.AbsoluteSolvationProtocol + """ + class Config: + arbitrary_types_allowed = True + + # Inherited things + forcefield_settings: OpenMMSystemGeneratorFFSettings + """Parameters to set up the force field with OpenMM Force Fields""" + thermo_settings: ThermoSettings + """Settings for thermodynamic parameters""" + + # Things for creating the systems + vacuum_system_settings: SystemSettings + """ + Simulation system settings including the + long-range non-bonded methods for the vacuum transformation. + """ + solvent_system_settings: SystemSettings + """ + Simulation system settings including the + long-range non-bonded methods for the solvent transformation. + """ + solvation_settings: SolvationSettings + """Settings for solvating the system.""" + + # Alchemical settings + alchemical_settings: AlchemicalSettings + """ + Alchemical protocol settings including lambda windows. + """ + alchemsampler_settings: AlchemicalSamplerSettings + """ + Settings for controling how we sample alchemical space, including the + number of repeats. + """ + + # MD Engine things + vacuum_engine_settings: OpenMMEngineSettings + """ + Settings specific to the OpenMM engine, such as the compute platform + for the vacuum transformation. + """ + solvent_engine_settings: OpenMMEngineSettings + """ + Settings specific to the OpenMM engine, such as the compute platform + for the solvent transformation. + """ + + # Sampling State defining things + integrator_settings: IntegratorSettings + """ + Settings for controlling the integrator, such as the timestep and + barostat settings. + """ + + # Simulation run settings + vacuum_simulation_settings: SimulationSettings + """ + Simulation control settings, including simulation lengths and + record-keeping for the vacuum transformation. + """ + solvent_simulation_settings: SimulationSettings + """ + Simulation control settings, including simulation lengths and + record-keeping for the solvent transformation. + """ diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py new file mode 100644 index 000000000..c259562b4 --- /dev/null +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -0,0 +1,727 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` +=============================================================================================================== + +This module implements the necessary methodology tooling to run calculate an +absolute solvation free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Disapearing molecules are only allowed in state A. Support for + appearing molecules will be added in due course. +* Only small molecules are allowed to act as alchemical molecules. + Alchemically changing protein or solvent components would induce + perturbations which are too large to be handled by this Protocol. + + +Acknowledgements +---------------- +* Originally based on hydration.py in + `espaloma `_ + +""" +from __future__ import annotations + +import logging + +from collections import defaultdict +import gufe +from gufe.components import Component +import itertools +import numpy as np +import numpy.typing as npt +from openff.units import unit +from typing import Dict, Optional, Union +from typing import Any, Iterable + +from gufe import ( + settings, + ChemicalSystem, SmallMoleculeComponent, + ProteinComponent, SolventComponent +) +from openfe.protocols.openmm_afe.equil_afe_settings import ( + AbsoluteSolvationSettings, SystemSettings, + SolvationSettings, AlchemicalSettings, + AlchemicalSamplerSettings, OpenMMEngineSettings, + IntegratorSettings, SimulationSettings, + SettingsBaseModel, +) +from ..openmm_utils import system_validation, settings_validation +from .base import BaseAbsoluteUnit +from openfe.utils import without_oechem_backend, log_system_probe + +logger = logging.getLogger(__name__) + + +class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): + """Dict-like container for the output of a AbsoluteSolvationProtocol + """ + def __init__(self, **data): + super().__init__(**data) + # TODO: Detect when we have extensions and stitch these together? + if any(len(pur_list) > 2 for pur_list + in itertools.chain(self.data['solvent'].values(), self.data['vacuum'].values())): + raise NotImplementedError("Can't stitch together results yet") + + def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[unit.Quantity, unit.Quantity]]] + A dictionary, keyed `solvent` and `vacuum` for each leg + of the thermodynamic cycle, with lists of tuples containing + the individual free energy estimates and associated MBAR + uncertainties for each repeat of that simulation type. + """ + vac_dGs = [] + solv_dGs = [] + + for pus in self.data['vacuum'].values(): + vac_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + + for pus in self.data['solvent'].values(): + solv_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + + return {'solvent': solv_dGs, 'vacuum': vac_dGs} + + def get_estimate(self): + """Get the solvation free energy estimate for this calculation. + + Returns + ------- + dG : unit.Quantity + The solvation free energy. This is a Quantity defined with units. + """ + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_dG = _get_average(individual_estimates['vacuum']) + solv_dG = _get_average(individual_estimates['solvent']) + + return vac_dG - solv_dG + + def get_uncertainty(self): + """Get the solvation free energy error for this calculation. + + Returns + ------- + err : unit.Quantity + The standard deviation between estimates of the solvation free + energy. This is a Quantity defined with units. + """ + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_err = _get_stdev(individual_estimates['vacuum']) + solv_err = _get_stdev(individual_estimates['solvent']) + + # return the combined error + return np.sqrt(vac_err**2 + solv_err**2) + + def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]]: + """ + Get the reverse and forward analysis of the free energies. + + Returns + ------- + forward_reverse : dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]] + A dictionary, keyed `solvent` and `vacuum` for each leg of the + thermodynamic cycle which each contain a list of dictionaries + containing the forward and reverse analysis of each repeat + of that simulation type. + + The forward and reverse analysis dictionaries contain: + - `fractions`: npt.NDArray + The fractions of data used for the estimates + - `forward_DGs`, `reverse_DGs`: unit.Quantity + The forward and reverse estimates for each fraction of data + - `forward_dDGs`, `reverse_dDGs`: unit.Quantity + The forward and reverse estimate uncertainty for each + fraction of data. + """ + + forward_reverse: dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]] = {} + + for key in ['solvent', 'vacuum']: + forward_reverse[key] = [ + pus[0].outputs['forward_and_reverse_energies'] + for pus in self.data[key].values() + ] + + return forward_reverse + + def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get a the MBAR overlap estimates for all legs of the simulation. + + Returns + ------- + overlap_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `vacuum` for each + leg of the thermodynamic cycle, which each containing a + list of dictionaries with the MBAR overlap estimates of + each repeat of that simulation type. + + The underlying MBAR dictionaries contain the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + + for key in ['solvent', 'vacuum']: + overlap_stats[key] = [ + pus[0].outputs['unit_mbar_overlap'] + for pus in self.data[key].values() + ] + + return overlap_stats + + def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get the replica exchange transition statistics for all + legs of the simulation. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `vacuum` for each + leg of the thermodynamic cycle, which each containing + a list of dictionaries containing the replica transition + statistics for each repeat of that simulation type. + + The replica transition statistics dictionaries contain the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + try: + for key in ['solvent', 'vacuum']: + repex_stats[key] = [ + pus[0].outputs['replica_exchange_statistics'] + for pus in self.data[key].values() + ] + except KeyError: + errmsg = ("Replica exchange statistics were not found, " + "did you run a repex calculation?") + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> dict[str, list[npt.NDArray]]: + """ + Get the timeseries of replica states for all simulation legs. + + Returns + ------- + replica_states : dict[str, list[npt.NDArray]] + Dictionary keyed `solvent` and `vacuum` for each leg of + the thermodynamic cycle, with lists of replica states + timeseries for each repeat of that simulation type. + """ + replica_states: dict[str, list[npt.NDArray]] = {} + + for key in ['solvent', 'vacuum']: + replica_states[key] = [ + pus[0].outputs['replica_states'] + for pus in self.data[key].values() + ] + return replica_states + + def equilibration_iterations(self) -> dict[str, list[float]]: + """ + Get the number of equilibration iterations for each simulation. + + Returns + ------- + equilibration_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `vacuum` for each leg + of the thermodynamic cycle, with lists containing the + number of equilibration iterations for each repeat + of that simulation type. + """ + equilibration_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'vacuum']: + equilibration_lengths[key] = [ + pus[0].outputs['equilibration_iterations'] + for pus in self.data[key].values() + ] + + return equilibration_lengths + + def production_iterations(self) -> dict[str, list[float]]: + """ + Get the number of production iterations for each simulation. + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `vacuum` for each leg of the + thermodynamic cycle, with lists with the number + of production iterations for each repeat of that simulation + type. + """ + production_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'vacuum']: + production_lengths[key] = [ + pus[0].outputs['production_iterations'] + for pus in self.data[key].values() + ] + + return production_lengths + + +class AbsoluteSolvationProtocol(gufe.Protocol): + """ + Absolute solvation free energy calculations using OpenMM and OpenMMTools. + + See Also + -------- + openfe.protocols + openfe.protocols.openmm_afe.AbsoluteSolvationSettings + openfe.protocols.openmm_afe.AbsoluteSolvationProtocolResult + openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit + openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit + """ + result_cls = AbsoluteSolvationProtocolResult + _settings: AbsoluteSolvationSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + return AbsoluteSolvationSettings( + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * unit.kelvin, + pressure=1 * unit.bar, + ), + solvent_system_settings=SystemSettings(), + vacuum_system_settings=SystemSettings(nonbonded_method='nocutoff'), + alchemical_settings=AlchemicalSettings(), + alchemsampler_settings=AlchemicalSamplerSettings( + n_replicas=24, + ), + solvation_settings=SolvationSettings(), + vacuum_engine_settings=OpenMMEngineSettings(), + solvent_engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + solvent_simulation_settings=SimulationSettings( + equilibration_length=1.0 * unit.nanosecond, + production_length=10.0 * unit.nanosecond, + output_filename='solvent.nc', + checkpoint_storage='solvent_checkpoint.nc', + ), + vacuum_simulation_settings=SimulationSettings( + equilibration_length=0.5 * unit.nanosecond, + production_length=2.0 * unit.nanosecond, + output_filename='vacuum.nc', + checkpoint_storage='vacuum_checkpoint.nc' + ), + ) + + @staticmethod + def _validate_solvent_endstates( + stateA: ChemicalSystem, stateB: ChemicalSystem, + ) -> None: + """ + A solvent transformation is defined (in terms of gufe components) + as starting from one or more ligands in solvent and + ending up in a state with one less ligand. + + No protein components are allowed. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A + stateB : ChemicalSystem + The chemical system of end state B + + Raises + ------ + ValueError + If stateA or stateB contains a ProteinComponent + If there is no SolventComponent in either stateA or stateB + """ + # Check that there are no protein components + for comp in itertools.chain(stateA.values(), stateB.values()): + if isinstance(comp, ProteinComponent): + errmsg = ("Protein components are not allowed for " + "absolute solvation free energies") + raise ValueError(errmsg) + + # check that there is a solvent component + if not any( + isinstance(comp, SolventComponent) for comp in stateA.values() + ): + errmsg = "No SolventComponent found in stateA" + raise ValueError(errmsg) + + if not any( + isinstance(comp, SolventComponent) for comp in stateB.values() + ): + errmsg = "No SolventComponent found in stateB" + raise ValueError(errmsg) + + @staticmethod + def _validate_alchemical_components( + alchemical_components: dict[str, list[Component]] + ) -> None: + """ + Checks that the ChemicalSystem alchemical components are correct. + + Parameters + ---------- + alchemical_components : Dict[str, list[Component]] + Dictionary containing the alchemical components for + stateA and stateB. + + Raises + ------ + ValueError + If there are alchemical components in state B. + If there are non SmallMoleculeComponent alchemical species. + If there are more than one alchemical species. + + Notes + ----- + * Currently doesn't support alchemical components in state B. + * Currently doesn't support alchemical components which are not + SmallMoleculeComponents. + * Currently doesn't support more than one alchemical component + being desolvated. + """ + + # Crash out if there are any alchemical components in state B for now + if len(alchemical_components['stateB']) > 0: + errmsg = ("Components appearing in state B are not " + "currently supported") + raise ValueError(errmsg) + + if len(alchemical_components['stateA']) > 1: + errmsg = ("More than one alchemical components is not supported " + "for absolute solvation free energies") + raise ValueError(errmsg) + + # Crash out if any of the alchemical components are not + # SmallMoleculeComponent + for comp in alchemical_components['stateA']: + if not isinstance(comp, SmallMoleculeComponent): + errmsg = ("Non SmallMoleculeComponent alchemical species " + "are not currently supported") + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Dict[str, gufe.ComponentMapping]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # TODO: extensions + if extends: # pragma: no-cover + raise NotImplementedError("Can't extend simulations yet") + + # Validate components and get alchemical components + self._validate_solvent_endstates(stateA, stateB) + alchem_comps = system_validation.get_alchemical_components( + stateA, stateB, + ) + self._validate_alchemical_components(alchem_comps) + + # Check nonbond & solvent compatibility + solv_nonbonded_method = self.settings.solvent_system_settings.nonbonded_method + vac_nonbonded_method = self.settings.vacuum_system_settings.nonbonded_method + # Use the more complete system validation solvent checks + system_validation.validate_solvent(stateA, solv_nonbonded_method) + # Gas phase is always gas phase + if vac_nonbonded_method.lower() != 'nocutoff': + errmsg = ("Only the nocutoff nonbonded_method is supported for " + f"vacuum calculations, {vac_nonbonded_method} was " + "passed") + raise ValueError(errmsg) + + # Get the name of the alchemical species + alchname = alchem_comps['stateA'][0].name + + # Create list units for vacuum and solvent transforms + + solvent_units = [ + AbsoluteSolvationSolventUnit( + stateA=stateA, stateB=stateB, + settings=self.settings, + alchemical_components=alchem_comps, + generation=0, repeat_id=i, + name=(f"Absolute Solvation, {alchname} solvent leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.alchemsampler_settings.n_repeats) + ] + + vacuum_units = [ + AbsoluteSolvationVacuumUnit( + # These don't really reflect the actual transform + # Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ? + stateA=stateA, stateB=stateB, + settings=self.settings, + alchemical_components=alchem_comps, + generation=0, repeat_id=i, + name=(f"Absolute Solvation, {alchname} vacuum leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.alchemsampler_settings.n_repeats) + ] + + return solvent_units + vacuum_units + + def _gather( + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + ) -> Dict[str, Dict[str, Any]]: + # result units will have a repeat_id and generation + # first group according to repeat_id + unsorted_solvent_repeats = defaultdict(list) + unsorted_vacuum_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if not pu.ok(): + continue + if pu.outputs['simtype'] == 'solvent': + unsorted_solvent_repeats[pu.outputs['repeat_id']].append(pu) + else: + unsorted_vacuum_repeats[pu.outputs['repeat_id']].append(pu) + + repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { + 'solvent': {}, 'vacuum': {}, + } + for k, v in unsorted_solvent_repeats.items(): + repeats['solvent'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + + for k, v in unsorted_vacuum_repeats.items(): + repeats['vacuum'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + return repeats + + +class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): + def _get_components(self) -> tuple[dict[str, list[Component]], None, + Optional[ProteinComponent], + list[SmallMoleculeComponent]]: + """ + Get the relevant components for a vacuum transformation. + + Returns + ------- + alchem_comps : dict[str, list[Component]] + A list of alchemical components + solv_comp : None + For the gas phase transformation, None will always be returned + for the solvent component of the chemical system. + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : list[SmallMoleculeComponent] + A list of SmallMoleculeComponents to add to the system. This + is equivalent to the alchemical components in stateA (since + we only allow for disappearing ligands). + """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + _, prot_comp, _ = system_validation.get_components(stateA) + + # Notes: + # 1. Our input state will contain a solvent, we ``None`` that out + # since this is the gas phase unit. + # 2. Our small molecules will always just be the alchemical components + # (of stateA since we enforce only one disappearing ligand) + return alchem_comps, None, prot_comp, alchem_comps['stateA'] + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * system_settings : SystemSettings + * solvation_settings : SolvationSettings + * alchemical_settings : AlchemicalSettings + * sampler_settings : AlchemicalSamplerSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * simulation_settings : SimulationSettings + """ + prot_settings = self._inputs['settings'] + + settings = {} + settings['forcefield_settings'] = prot_settings.forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['system_settings'] = prot_settings.vacuum_system_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['sampler_settings'] = prot_settings.alchemsampler_settings + settings['engine_settings'] = prot_settings.vacuum_engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['simulation_settings'] = prot_settings.vacuum_simulation_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> Dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + with without_oechem_backend(): + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'vacuum', + **outputs + } + + +class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): + def _get_components(self) -> tuple[list[Component], SolventComponent, + Optional[ProteinComponent], + list[SmallMoleculeComponent]]: + """ + Get the relevant components for a vacuum transformation. + + Returns + ------- + alchem_comps : list[Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : list[SmallMoleculeComponent] + A list of SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp since that's also + # disallowed on create + return alchem_comps, solv_comp, prot_comp, small_mols + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * system_settings : SystemSettings + * solvation_settings : SolvationSettings + * alchemical_settings : AlchemicalSettings + * sampler_settings : AlchemicalSamplerSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * simulation_settings : SimulationSettings + """ + prot_settings = self._inputs['settings'] + + settings = {} + settings['forcefield_settings'] = prot_settings.forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['system_settings'] = prot_settings.solvent_system_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['sampler_settings'] = prot_settings.alchemsampler_settings + settings['engine_settings'] = prot_settings.solvent_engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['simulation_settings'] = prot_settings.solvent_simulation_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> Dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + with without_oechem_backend(): + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'solvent', + **outputs + } diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py index ad4e5b507..b3bee28f6 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py @@ -30,6 +30,15 @@ def get_openmm_platform(platform_name=None): from openmmtools.utils import get_fastest_platform platform = get_fastest_platform(minimum_precision='mixed') else: + try: + platform_name = { + 'cpu': 'CPU', + 'opencl': 'OpenCL', + 'cuda': 'CUDA', + }[str(platform_name).lower()] + except KeyError: + pass + from openmm import Platform platform = Platform.getPlatformByName(platform_name) # Set precision and properties diff --git a/openfe/protocols/openmm_rfe/equil_rfe_settings.py b/openfe/protocols/openmm_rfe/equil_rfe_settings.py index 230d05ff6..d16aeea38 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_settings.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_settings.py @@ -2,7 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/openfe """Equilibrium Relative Free Energy Protocol input settings. -This module implements the necessary settings necessary to run absolute free +This module implements the necessary settings necessary to run relative free energies using :class:`openfe.protocols.openmm_rfe.equil_rfe_methods.py` """ diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index 428f768d6..063560814 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -1,10 +1,11 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -"""Equilibrium Relative Free Energy Protocol input settings. - -This module implements the necessary settings necessary to run absolute free -energies using :class:`openfe.protocols.openmm_rfe.equil_rfe_methods.py` +"""Equilibrium Free Energy Protocols input settings. +This module implements base settings necessary to run +free energy calculations using OpenMM +/- Tools, such +as :mod:`openfe.protocols.openmm_rfe.equil_rfe_methods.py` +and :mod`openfe.protocols.openmm_afe.equil_afe_methods.py` """ from __future__ import annotations @@ -19,11 +20,13 @@ ThermoSettings, ) + try: from pydantic.v1 import validator except ImportError: from pydantic import validator # type: ignore[assignment] + class SystemSettings(SettingsBaseModel): """Settings describing the simulation system settings.""" diff --git a/openfe/tests/data/openmm_afe/CN_absolute_solvation_transformation.json.gz b/openfe/tests/data/openmm_afe/CN_absolute_solvation_transformation.json.gz new file mode 100644 index 000000000..c8be33477 Binary files /dev/null and b/openfe/tests/data/openmm_afe/CN_absolute_solvation_transformation.json.gz differ diff --git a/openfe/tests/data/openmm_afe/__init__.py b/openfe/tests/data/openmm_afe/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/openfe/tests/data/openmm_afe/__init__.py @@ -0,0 +1 @@ + diff --git a/openfe/tests/protocols/conftest.py b/openfe/tests/protocols/conftest.py index c18204ce0..2f46d9704 100644 --- a/openfe/tests/protocols/conftest.py +++ b/openfe/tests/protocols/conftest.py @@ -143,9 +143,21 @@ def toluene_many_solv_system(benzene_modifications): @pytest.fixture -def transformation_json() -> str: - """string of a result of quickrun""" +def rfe_transformation_json() -> str: + """string of a RFE result of quickrun""" d = resources.files('openfe.tests.data.openmm_rfe') with gzip.open((d / 'Transformation-e1702a3efc0fa735d5c14fc7572b5278_results.json.gz').as_posix(), 'r') as f: # type: ignore return f.read().decode() # type: ignore + + +@pytest.fixture +def afe_solv_transformation_json() -> str: + """ + string of a Absolute Solvation result (CN in water) generated by quickrun + """ + d = resources.files('openfe.tests.data.openmm_afe') + fname = "CN_absolute_solvation_transformation.json.gz" + + with gzip.open((d / fname).as_posix(), 'r') as f: # type: ignore + return f.read().decode() # type: ignore diff --git a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py new file mode 100644 index 000000000..14a7b5c9c --- /dev/null +++ b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py @@ -0,0 +1,568 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import itertools +import json +import pytest +from unittest import mock +from openmmtools.multistate.multistatesampler import MultiStateSampler +from openff.units import unit as offunit +import mdtraj as mdt +import numpy as np +import gufe +import openfe +from openfe import ChemicalSystem, SolventComponent +from openfe.protocols import openmm_afe +from openfe.protocols.openmm_afe import ( + AbsoluteSolvationSolventUnit, + AbsoluteSolvationVacuumUnit, + AbsoluteSolvationProtocol, +) +from openfe.protocols.openmm_utils import system_validation + + +@pytest.fixture() +def default_settings(): + return AbsoluteSolvationProtocol.default_settings() + + +def test_create_default_settings(): + settings = AbsoluteSolvationProtocol.default_settings() + assert settings + + +@pytest.mark.parametrize('val', [ + {'elec': 0, 'vdw': 5}, + {'elec': -2, 'vdw': 5}, + {'elec': 5, 'vdw': -2}, + {'elec': 5, 'vdw': 0}, +]) +def test_incorrect_window_settings(val, default_settings): + errmsg = "lambda steps must be positive" + alchem_settings = default_settings.alchemical_settings + with pytest.raises(ValueError, match=errmsg): + alchem_settings.lambda_elec_windows = val['elec'] + alchem_settings.lambda_vdw_windows = val['vdw'] + + +def test_create_default_protocol(default_settings): + # this is roughly how it should be created + protocol = AbsoluteSolvationProtocol( + settings=default_settings, + ) + assert protocol + + +def test_serialize_protocol(default_settings): + protocol = AbsoluteSolvationProtocol( + settings=default_settings, + ) + + ser = protocol.to_dict() + ret = AbsoluteSolvationProtocol.from_dict(ser) + assert protocol == ret + + +def test_validate_solvent_endstates_protcomp( + benzene_modifications,T4_protein_component +): + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'protein': T4_protein_component, + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'phenol': benzene_modifications['phenol'], + 'solvent': SolventComponent(), + }) + + with pytest.raises(ValueError, match="Protein components are not allowed"): + comps = AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) + + +def test_validate_solvent_endstates_nosolvcomp_stateA( + benzene_modifications, T4_protein_component +): + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + }) + + stateB = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'phenol': benzene_modifications['phenol'], + 'solvent': SolventComponent(), + }) + + with pytest.raises( + ValueError, match="No SolventComponent found in stateA" + ): + comps = AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) + + +def test_validate_solvent_endstates_nosolvcomp_stateB( + benzene_modifications, T4_protein_component +): + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent(), + }) + + stateB = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'phenol': benzene_modifications['phenol'], + }) + + with pytest.raises( + ValueError, match="No SolventComponent found in stateB" + ): + comps = AbsoluteSolvationProtocol._validate_solvent_endstates(stateA, stateB) + +def test_validate_alchem_comps_appearingB(benzene_modifications): + stateA = ChemicalSystem({ + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + with pytest.raises(ValueError, match='Components appearing in state B'): + AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) + + +def test_validate_alchem_comps_multi(benzene_modifications): + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'toluene': benzene_modifications['toluene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent() + }) + + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + assert len(alchem_comps['stateA']) == 2 + + with pytest.raises(ValueError, match='More than one alchemical'): + AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) + + +def test_validate_alchem_nonsmc(benzene_modifications): + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + }) + + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + with pytest.raises(ValueError, match='Non SmallMoleculeComponent'): + AbsoluteSolvationProtocol._validate_alchemical_components(alchem_comps) + + +def test_vac_bad_nonbonded(benzene_modifications): + settings = openmm_afe.AbsoluteSolvationProtocol.default_settings() + settings.vacuum_system_settings.nonbonded_method = 'pme' + protocol = openmm_afe.AbsoluteSolvationProtocol(settings=settings) + + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + + with pytest.raises(ValueError, match='Only the nocutoff'): + protocol.create(stateA=stateA, stateB=stateB, mapping=None) + + +@pytest.mark.parametrize('method', [ + 'repex', 'sams', 'independent', 'InDePeNdENT' +]) +def test_dry_run_vac_benzene(benzene_modifications, + method, tmpdir): + s = openmm_afe.AbsoluteSolvationProtocol.default_settings() + s.alchemsampler_settings.n_repeats = 1 + s.alchemsampler_settings.sampler_method = method + + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first vacuum unit + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + assert len(prot_units) == 2 + + vac_unit = [u for u in prot_units + if isinstance(u, AbsoluteSolvationVacuumUnit)] + sol_unit = [u for u in prot_units + if isinstance(u, AbsoluteSolvationSolventUnit)] + + assert len(vac_unit) == 1 + assert len(sol_unit) == 1 + + with tmpdir.as_cwd(): + vac_sampler = vac_unit[0].run(dry=True)['debug']['sampler'] + assert not vac_sampler.is_periodic + + +def test_dry_run_solv_benzene(benzene_modifications, tmpdir): + s = openmm_afe.AbsoluteSolvationProtocol.default_settings() + s.alchemsampler_settings.n_repeats = 1 + s.solvent_simulation_settings.output_indices = "resname UNK" + + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first solvent unit + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + assert len(prot_units) == 2 + + vac_unit = [u for u in prot_units + if isinstance(u, AbsoluteSolvationVacuumUnit)] + sol_unit = [u for u in prot_units + if isinstance(u, AbsoluteSolvationSolventUnit)] + + assert len(vac_unit) == 1 + assert len(sol_unit) == 1 + + with tmpdir.as_cwd(): + sol_sampler = sol_unit[0].run(dry=True)['debug']['sampler'] + assert sol_sampler.is_periodic + + pdb = mdt.load_pdb('hybrid_system.pdb') + assert pdb.n_atoms == 12 + + +def test_dry_run_solv_benzene_tip4p(benzene_modifications, tmpdir): + s = AbsoluteSolvationProtocol.default_settings() + s.alchemsampler_settings.n_repeats = 1 + s.forcefield_settings.forcefields = [ + "amber/ff14SB.xml", # ff14SB protein force field + "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS + "amber/phosaa10.xml", # Handles THE TPO + ] + s.solvation_settings.solvent_model = 'tip4pew' + s.integrator_settings.reassign_velocities = True + + protocol = AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first solvent unit + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + sol_unit = [u for u in prot_units + if isinstance(u, AbsoluteSolvationSolventUnit)] + + with tmpdir.as_cwd(): + sol_sampler = sol_unit[0].run(dry=True)['debug']['sampler'] + assert sol_sampler.is_periodic + + +def test_nreplicas_lambda_mismatch(benzene_modifications, tmpdir): + s = AbsoluteSolvationProtocol.default_settings() + s.alchemsampler_settings.n_repeats = 1 + s.alchemsampler_settings.n_replicas = 12 + + protocol = AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + with tmpdir.as_cwd(): + errmsg = "Number of replicas 12" + with pytest.raises(ValueError, match=errmsg): + prot_units[0].run(dry=True) + + +def test_high_timestep(benzene_modifications, tmpdir): + s = AbsoluteSolvationProtocol.default_settings() + s.alchemsampler_settings.n_repeats = 1 + s.forcefield_settings.hydrogen_mass = 1.0 + + protocol = AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + with tmpdir.as_cwd(): + errmsg = "too large for hydrogen mass" + with pytest.raises(ValueError, match=errmsg): + prot_units[0].run(dry=True) + + +@pytest.fixture +def benzene_solvation_dag(benzene_modifications): + s = AbsoluteSolvationProtocol.default_settings() + + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + return protocol.create(stateA=stateA, stateB=stateB, mapping=None) + + +def test_unit_tagging(benzene_solvation_dag, tmpdir): + # test that executing the units includes correct gen and repeat info + + dag_units = benzene_solvation_dag.protocol_units + + with ( + mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run', + return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run', + return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + ): + results = [] + for u in dag_units: + ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) + results.append(ret) + + solv_repeats = set() + vac_repeats = set() + for ret in results: + assert isinstance(ret, gufe.ProtocolUnitResult) + assert ret.outputs['generation'] == 0 + if ret.outputs['simtype'] == 'vacuum': + vac_repeats.add(ret.outputs['repeat_id']) + else: + solv_repeats.add(ret.outputs['repeat_id']) + assert vac_repeats == {0, 1, 2} + assert solv_repeats == {0, 1, 2} + + +def test_gather(benzene_solvation_dag, tmpdir): + # check that .gather behaves as expected + with ( + mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run', + return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + mock.patch('openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run', + return_value={'nc': 'file.nc', 'last_checkpoint': 'chck.nc'}), + ): + dagres = gufe.protocols.execute_DAG(benzene_solvation_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True) + + protocol = AbsoluteSolvationProtocol( + settings=AbsoluteSolvationProtocol.default_settings(), + ) + + res = protocol.gather([dagres]) + + assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult) + + +class TestProtocolResult: + @pytest.fixture() + def protocolresult(self, afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, + cls=gufe.tokenization.JSON_HANDLER.decoder) + + pr = openfe.ProtocolResult.from_dict(d['protocol_result']) + + return pr + + def test_reload_protocol_result(self, afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, + cls=gufe.tokenization.JSON_HANDLER.decoder) + + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d['protocol_result']) + + assert pr + + def test_get_estimate(self, protocolresult): + est = protocolresult.get_estimate() + + assert est + assert est.m == pytest.approx(-3.00208997) + assert isinstance(est, offunit.Quantity) + assert est.is_compatible_with(offunit.kilojoule_per_mole) + + def test_get_uncertainty(self, protocolresult): + est = protocolresult.get_uncertainty() + + assert est + assert est.m == pytest.approx(0.1577349) + assert isinstance(est, offunit.Quantity) + assert est.is_compatible_with(offunit.kilojoule_per_mole) + + def test_get_individual(self, protocolresult): + inds = protocolresult.get_individual_estimates() + + assert isinstance(inds, dict) + assert isinstance(inds['solvent'], list) + assert isinstance(inds['vacuum'], list) + assert len(inds['solvent']) == len(inds['vacuum']) == 3 + for e, u in itertools.chain(inds['solvent'], inds['vacuum']): + assert e.is_compatible_with(offunit.kilojoule_per_mole) + assert u.is_compatible_with(offunit.kilojoule_per_mole) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_get_forwards_etc(self, key, protocolresult): + far = protocolresult.get_forward_and_reverse_energy_analysis() + + assert isinstance(far, dict) + assert isinstance(far[key], list) + far1 = far[key][0] + assert isinstance(far1, dict) + + for k in ['fractions', 'forward_DGs', 'forward_dDGs', + 'reverse_DGs', 'reverse_dDGs']: + assert k in far1 + + if k == 'fractions': + assert isinstance(far1[k], np.ndarray) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_get_overlap_matrices(self, key, protocolresult): + ovp = protocolresult.get_overlap_matrices() + + assert isinstance(ovp, dict) + assert isinstance(ovp[key], list) + assert len(ovp[key]) == 3 + + ovp1 = ovp[key][0] + assert isinstance(ovp1['matrix'], np.ndarray) + assert ovp1['matrix'].shape == (15, 15) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_get_replica_transition_statistics(self, key, protocolresult): + rpx = protocolresult.get_replica_transition_statistics() + + assert isinstance(rpx, dict) + assert isinstance(rpx[key], list) + assert len(rpx[key]) == 3 + rpx1 = rpx[key][0] + assert 'eigenvalues' in rpx1 + assert 'matrix' in rpx1 + assert rpx1['eigenvalues'].shape == (15,) + assert rpx1['matrix'].shape == (15, 15) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_get_replica_states(self, key, protocolresult): + rep = protocolresult.get_replica_states() + + assert isinstance(rep, dict) + assert isinstance(rep[key], list) + assert len(rep[key]) == 3 + assert rep[key][0].shape == (251, 15) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_equilibration_iterations(self, key, protocolresult): + eq = protocolresult.equilibration_iterations() + + assert isinstance(eq, dict) + assert isinstance(eq[key], list) + assert len(eq[key]) == 3 + assert all(isinstance(v, float) for v in eq[key]) + + @pytest.mark.parametrize('key', ['solvent', 'vacuum']) + def test_production_iterations(self, key, protocolresult): + prod = protocolresult.production_iterations() + + assert isinstance(prod, dict) + assert isinstance(prod[key], list) + assert len(prod[key]) == 3 + assert all(isinstance(v, float) for v in prod[key]) diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 3445d7b2f..50c3c9f22 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -1243,16 +1243,16 @@ def test_constraints(tyk2_xml, tyk2_reference_xml): class TestProtocolResult: @pytest.fixture() - def protocolresult(self, transformation_json): - d = json.loads(transformation_json, + def protocolresult(self, rfe_transformation_json): + d = json.loads(rfe_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) pr = openfe.ProtocolResult.from_dict(d['protocol_result']) return pr - def test_reload_protocol_result(self, transformation_json): - d = json.loads(transformation_json, + def test_reload_protocol_result(self, rfe_transformation_json): + d = json.loads(rfe_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) pr = openmm_rfe.RelativeHybridTopologyProtocolResult.from_dict(d['protocol_result']) diff --git a/openfe/tests/protocols/test_solvation_afe_tokenization.py b/openfe/tests/protocols/test_solvation_afe_tokenization.py new file mode 100644 index 000000000..576655c2a --- /dev/null +++ b/openfe/tests/protocols/test_solvation_afe_tokenization.py @@ -0,0 +1,93 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import json +import openfe +from openfe.protocols import openmm_afe +import gufe +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin +import pytest + + +@pytest.fixture +def protocol(): + return openmm_afe.AbsoluteSolvationProtocol( + openmm_afe.AbsoluteSolvationProtocol.default_settings() + ) + + +@pytest.fixture +def protocol_units(protocol, benzene_system): + pus = protocol.create( + stateA=benzene_system, + stateB=openfe.ChemicalSystem({'solvent': openfe.SolventComponent()}), + mapping=None, + ) + return list(pus.protocol_units) + + +@pytest.fixture +def solvent_protocol_unit(protocol_units): + for pu in protocol_units: + if isinstance(pu, openmm_afe.AbsoluteSolvationSolventUnit): + return pu + + +@pytest.fixture +def vacuum_protocol_unit(protocol_units): + for pu in protocol_units: + if isinstance(pu, openmm_afe.AbsoluteSolvationVacuumUnit): + return pu + + +@pytest.fixture +def protocol_result(afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, + cls=gufe.tokenization.JSON_HANDLER.decoder) + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d['protocol_result']) + return pr + + +class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationProtocol + key = "AbsoluteSolvationProtocol-fd22076bcea777207beb86ef7a6ded81" + repr = f"<{key}>" + + @pytest.fixture() + def instance(self, protocol): + return protocol + + +class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationSolventUnit + repr = "AbsoluteSolvationSolventUnit(Absolute Solvation, benzene solvent leg: repeat 2 generation 0)" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_unit): + return solvent_protocol_unit + + def test_key_stable(self): + pytest.skip() + + +class TestAbsoluteSolvationVacuumUnit(GufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationVacuumUnit + repr = "AbsoluteSolvationVacuumUnit(Absolute Solvation, benzene vacuum leg: repeat 2 generation 0)" + key = None + + @pytest.fixture() + def instance(self, vacuum_protocol_unit): + return vacuum_protocol_unit + + def test_key_stable(self): + pytest.skip() + + +class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationProtocolResult + key = "AbsoluteSolvationProtocolResult-8caab27e7ad1bd544a787ac639f5f447" + repr = f"<{key}>" + + @pytest.fixture() + def instance(self, protocol_result): + return protocol_result