Skip to content

Commit

Permalink
Merge pull request #700 from OpenFreeEnergy/protocol-to-units
Browse files Browse the repository at this point in the history
Switch settings input to protocol units to creating protocol itself.
  • Loading branch information
richardjgowers authored Feb 7, 2024
2 parents 0219070 + 48b31c4 commit 00ef4de
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 45 deletions.
12 changes: 5 additions & 7 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,25 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit):
Base class for ligand absolute free energy transformations.
"""
def __init__(self, *,
protocol: gufe.Protocol,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
settings: settings.Settings,
alchemical_components: dict[str, list[Component]],
generation: int = 0,
repeat_id: int = 0,
name: Optional[str] = None,):
"""
Parameters
----------
protocol : gufe.Protocol
protocol used to create this Unit. Contains key information such
as the settings.
stateA : ChemicalSystem
ChemicalSystem containing the components defining the state at
lambda 0.
stateB : ChemicalSystem
ChemicalSystem containing the components defining the state at
lambda 1.
settings : gufe.settings.Setings
Settings for the Absolute Tranformation Protocol. This can be
constructed by calling the
:class:`AbsoluteTransformProtocol.get_default_settings` method
to get a default set of settings.
alchemical_components : dict[str, Component]
the alchemical components for each state in this Unit
name : str, optional
Expand All @@ -104,9 +102,9 @@ def __init__(self, *,
"""
super().__init__(
name=name,
protocol=protocol,
stateA=stateA,
stateB=stateB,
settings=settings,
alchemical_components=alchemical_components,
repeat_id=repeat_id,
generation=generation,
Expand Down
14 changes: 8 additions & 6 deletions openfe/protocols/openmm_afe/equil_solvation_afe_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,9 @@ def _create(

solvent_units = [
AbsoluteSolvationSolventUnit(
stateA=stateA, stateB=stateB,
settings=self.settings,
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0, repeat_id=int(uuid.uuid4()),
name=(f"Absolute Solvation, {alchname} solvent leg: "
Expand All @@ -655,8 +656,9 @@ def _create(
AbsoluteSolvationVacuumUnit(
# These don't really reflect the actual transform
# Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ?
stateA=stateA, stateB=stateB,
settings=self.settings,
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0, repeat_id=int(uuid.uuid4()),
name=(f"Absolute Solvation, {alchname} vacuum leg: "
Expand Down Expand Up @@ -747,7 +749,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]:
* simulation_settings : SimulationSettings
* output_settings: OutputSettings
"""
prot_settings = self._inputs['settings']
prot_settings = self._inputs['protocol'].settings

settings = {}
settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings
Expand Down Expand Up @@ -831,7 +833,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]:
* simulation_settings : MultiStateSimulationSettings
* output_settings: OutputSettings
"""
prot_settings = self._inputs['settings']
prot_settings = self._inputs['protocol'].settings

settings = {}
settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings
Expand Down
32 changes: 16 additions & 16 deletions openfe/protocols/openmm_md/plain_md_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,16 @@ def _create(
# our DAG has no dependencies, so just list units
n_repeats = self.settings.protocol_repeats
units = [PlainMDProtocolUnit(
protocol=self,
stateA=stateA,
settings=self.settings,
generation=0, repeat_id=int(uuid.uuid4()),
name=f'{system_name} repeat {i} generation 0')
for i in range(n_repeats)]

return units

def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
) -> dict[str, Any]:
# result units will have a repeat_id and generations within this
# repeat_id
Expand Down Expand Up @@ -206,22 +206,23 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit):
Protocol unit for plain MD simulations (NonTransformation).
"""

def __init__(self, *,
stateA: ChemicalSystem,
settings: PlainMDProtocolSettings,
generation: int,
repeat_id: int,
name: Optional[str] = None,
):
def __init__(
self,
*,
protocol: PlainMDProtocol,
stateA: ChemicalSystem,
generation: int,
repeat_id: int,
name: Optional[str] = None,
):
"""
Parameters
----------
protocol : PlainMDProtocol
protocol used to create this Unit. Contains key information such
as the settings.
stateA : ChemicalSystem
the chemical system for the MD simulation
settings : settings.Settings
the settings for the Method. This can be constructed using the
get_default_settings classmethod to give a starting point that
can be updated to suit.
repeat_id : int
identifier for which repeat (aka replica/clone) this Unit is
generation : int
Expand All @@ -236,8 +237,8 @@ def __init__(self, *,
"""
super().__init__(
name=name,
protocol=protocol,
stateA=stateA,
settings=settings,
repeat_id=repeat_id,
generation=generation
)
Expand Down Expand Up @@ -457,8 +458,7 @@ def run(self, *, dry=False, verbose=True,
# 0. General setup and settings dependency resolution step

# Extract relevant settings
protocol_settings: PlainMDProtocolSettings = self._inputs[
'settings']
protocol_settings: PlainMDProtocolSettings = self._inputs['protocol'].settings
stateA = self._inputs['stateA']

forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings
Expand Down
33 changes: 17 additions & 16 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ def _create(
# our DAG has no dependencies, so just list units
n_repeats = self.settings.protocol_repeats
units = [RelativeHybridTopologyProtocolUnit(
protocol=self,
stateA=stateA, stateB=stateB, ligandmapping=ligandmapping,
settings=self.settings,
generation=0, repeat_id=int(uuid.uuid4()),
name=f'{Anames} to {Bnames} repeat {i} generation 0')
for i in range(n_repeats)]
Expand Down Expand Up @@ -535,27 +535,28 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit):
Calculates the relative free energy of an alchemical ligand transformation.
"""

def __init__(self, *,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
ligandmapping: LigandAtomMapping,
settings: RelativeHybridTopologyProtocolSettings,
generation: int,
repeat_id: int,
name: Optional[str] = None,
):
def __init__(
self,
*,
protocol: RelativeHybridTopologyProtocol,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
ligandmapping: LigandAtomMapping,
generation: int,
repeat_id: int,
name: Optional[str] = None,
):
"""
Parameters
----------
protocol : RelativeHybridTopologyProtocol
protocol used to create this Unit. Contains key information such
as the settings.
stateA, stateB : ChemicalSystem
the two ligand SmallMoleculeComponents to transform between. The
transformation will go from ligandA to ligandB.
ligandmapping : LigandAtomMapping
the mapping of atoms between the two ligand components
settings : settings.Settings
the settings for the Method. This can be constructed using the
get_default_settings classmethod to give a starting point that
can be updated to suit.
repeat_id : int
identifier for which repeat (aka replica/clone) this Unit is
generation : int
Expand All @@ -570,10 +571,10 @@ def __init__(self, *,
"""
super().__init__(
name=name,
protocol=protocol,
stateA=stateA,
stateB=stateB,
ligandmapping=ligandmapping,
settings=settings,
repeat_id=repeat_id,
generation=generation
)
Expand Down Expand Up @@ -619,7 +620,7 @@ def run(self, *, dry=False, verbose=True,
# 0. General setup and settings dependency resolution step

# Extract relevant settings
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings
stateA = self._inputs['stateA']
stateB = self._inputs['stateB']
mapping = self._inputs['ligandmapping']
Expand Down

0 comments on commit 00ef4de

Please sign in to comment.