diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 67d03f7c5..d1cb29bb2 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -70,9 +70,9 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): Base class for ligand absolute free energy transformations. """ def __init__(self, *, + protocol: gufe.Protocol, stateA: ChemicalSystem, stateB: ChemicalSystem, - settings: settings.Settings, alchemical_components: dict[str, list[Component]], generation: int = 0, repeat_id: int = 0, @@ -80,17 +80,15 @@ def __init__(self, *, """ Parameters ---------- + protocol : gufe.Protocol + protocol used to create this Unit. Contains key information such + as the settings. stateA : ChemicalSystem ChemicalSystem containing the components defining the state at lambda 0. stateB : ChemicalSystem ChemicalSystem containing the components defining the state at lambda 1. - settings : gufe.settings.Setings - Settings for the Absolute Tranformation Protocol. This can be - constructed by calling the - :class:`AbsoluteTransformProtocol.get_default_settings` method - to get a default set of settings. alchemical_components : dict[str, Component] the alchemical components for each state in this Unit name : str, optional @@ -104,9 +102,9 @@ def __init__(self, *, """ super().__init__( name=name, + protocol=protocol, stateA=stateA, stateB=stateB, - settings=settings, alchemical_components=alchemical_components, repeat_id=repeat_id, generation=generation, diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index 6e18f2252..f482592a4 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -641,8 +641,9 @@ def _create( solvent_units = [ AbsoluteSolvationSolventUnit( - stateA=stateA, stateB=stateB, - settings=self.settings, + protocol=self, + stateA=stateA, + stateB=stateB, alchemical_components=alchem_comps, generation=0, repeat_id=int(uuid.uuid4()), name=(f"Absolute Solvation, {alchname} solvent leg: " @@ -655,8 +656,9 @@ def _create( AbsoluteSolvationVacuumUnit( # These don't really reflect the actual transform # Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ? - stateA=stateA, stateB=stateB, - settings=self.settings, + protocol=self, + stateA=stateA, + stateB=stateB, alchemical_components=alchem_comps, generation=0, repeat_id=int(uuid.uuid4()), name=(f"Absolute Solvation, {alchname} vacuum leg: " @@ -747,7 +749,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * simulation_settings : SimulationSettings * output_settings: OutputSettings """ - prot_settings = self._inputs['settings'] + prot_settings = self._inputs['protocol'].settings settings = {} settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings @@ -831,7 +833,7 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * simulation_settings : MultiStateSimulationSettings * output_settings: OutputSettings """ - prot_settings = self._inputs['settings'] + prot_settings = self._inputs['protocol'].settings settings = {} settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings diff --git a/openfe/protocols/openmm_md/plain_md_methods.py b/openfe/protocols/openmm_md/plain_md_methods.py index 2716d468c..9aa96ba2a 100644 --- a/openfe/protocols/openmm_md/plain_md_methods.py +++ b/openfe/protocols/openmm_md/plain_md_methods.py @@ -169,8 +169,8 @@ def _create( # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats units = [PlainMDProtocolUnit( + protocol=self, stateA=stateA, - settings=self.settings, generation=0, repeat_id=int(uuid.uuid4()), name=f'{system_name} repeat {i} generation 0') for i in range(n_repeats)] @@ -178,7 +178,7 @@ def _create( return units def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] ) -> dict[str, Any]: # result units will have a repeat_id and generations within this # repeat_id @@ -206,22 +206,23 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): Protocol unit for plain MD simulations (NonTransformation). """ - def __init__(self, *, - stateA: ChemicalSystem, - settings: PlainMDProtocolSettings, - generation: int, - repeat_id: int, - name: Optional[str] = None, - ): + def __init__( + self, + *, + protocol: PlainMDProtocol, + stateA: ChemicalSystem, + generation: int, + repeat_id: int, + name: Optional[str] = None, + ): """ Parameters ---------- + protocol : PlainMDProtocol + protocol used to create this Unit. Contains key information such + as the settings. stateA : ChemicalSystem the chemical system for the MD simulation - settings : settings.Settings - the settings for the Method. This can be constructed using the - get_default_settings classmethod to give a starting point that - can be updated to suit. repeat_id : int identifier for which repeat (aka replica/clone) this Unit is generation : int @@ -236,8 +237,8 @@ def __init__(self, *, """ super().__init__( name=name, + protocol=protocol, stateA=stateA, - settings=settings, repeat_id=repeat_id, generation=generation ) @@ -457,8 +458,7 @@ def run(self, *, dry=False, verbose=True, # 0. General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: PlainMDProtocolSettings = self._inputs[ - 'settings'] + protocol_settings: PlainMDProtocolSettings = self._inputs['protocol'].settings stateA = self._inputs['stateA'] forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index b362d6b79..e1c7bb3eb 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -499,8 +499,8 @@ def _create( # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats units = [RelativeHybridTopologyProtocolUnit( + protocol=self, stateA=stateA, stateB=stateB, ligandmapping=ligandmapping, - settings=self.settings, generation=0, repeat_id=int(uuid.uuid4()), name=f'{Anames} to {Bnames} repeat {i} generation 0') for i in range(n_repeats)] @@ -535,27 +535,28 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): Calculates the relative free energy of an alchemical ligand transformation. """ - def __init__(self, *, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ligandmapping: LigandAtomMapping, - settings: RelativeHybridTopologyProtocolSettings, - generation: int, - repeat_id: int, - name: Optional[str] = None, - ): + def __init__( + self, + *, + protocol: RelativeHybridTopologyProtocol, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ligandmapping: LigandAtomMapping, + generation: int, + repeat_id: int, + name: Optional[str] = None, + ): """ Parameters ---------- + protocol : RelativeHybridTopologyProtocol + protocol used to create this Unit. Contains key information such + as the settings. stateA, stateB : ChemicalSystem the two ligand SmallMoleculeComponents to transform between. The transformation will go from ligandA to ligandB. ligandmapping : LigandAtomMapping the mapping of atoms between the two ligand components - settings : settings.Settings - the settings for the Method. This can be constructed using the - get_default_settings classmethod to give a starting point that - can be updated to suit. repeat_id : int identifier for which repeat (aka replica/clone) this Unit is generation : int @@ -570,10 +571,10 @@ def __init__(self, *, """ super().__init__( name=name, + protocol=protocol, stateA=stateA, stateB=stateB, ligandmapping=ligandmapping, - settings=settings, repeat_id=repeat_id, generation=generation ) @@ -619,7 +620,7 @@ def run(self, *, dry=False, verbose=True, # 0. General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] + protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings stateA = self._inputs['stateA'] stateB = self._inputs['stateB'] mapping = self._inputs['ligandmapping']