Skip to content

Commit

Permalink
Add in the remaining protocol unit result methods
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Oct 16, 2023
1 parent 85eee9b commit c50bf90
Showing 1 changed file with 153 additions and 16 deletions.
169 changes: 153 additions & 16 deletions openfe/protocols/openmm_afe/equil_solvation_afe_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def __init__(self, **data):

def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Quantity]]]:
"""
Return a dictionary (keyed as `solvent` and `vacuum`) with list
of tuples containing the individual free energy estimates and
associated MBAR errors for each repeat of the vacuum and solvent
calculations.
Get the individual estimate of the free energies.
Returns
-------
dGs : dict[str, list[tuple[unit.Quantity, unit.Quantity]]]
A dictionary, keyed `solvent` and `vacuum` for each leg
of the thermodynamic cycle, with lists of tuples containing
the individual free energy estimates and associated MBAR
uncertainties for each repeat of that simulation type.
"""
vac_dGs = []
solv_dGs = []
Expand Down Expand Up @@ -155,22 +156,24 @@ def _get_stdev(estimates):

def get_forward_and_reverse_analysis(self) -> dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]]:
"""
Get a dictionary (keyed `solvent` and `vacuum`) with lists of the
reverse and forward analysis of the free energies for each repeat
of the vacuum and solvent calculations using uncorrelated production
samples.
The returned forward and reverse analysis dictionaries have keys:
'fractions' - the fraction of data used for this estimate
'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate
'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty
The 'fractions' values are a numpy array, while the other arrays are
Quantity arrays, with units attached.
Get the reverse and forward analysis of the free energies.
Returns
-------
forward_reverse : dict[str, list[dict[str, Union[npt.NDArray, unit.Quantity]]]]
A dictionary, keyed `solvent` and `vacuum` for each leg of the
thermodynamic cycle which each contain a list of dictionaries
containing the forward and reverse analysis of each repeat
of that simulation type.
The forward and reverse analysis dictionaries contain:
- `fractions`: npt.NDArray
The fractions of data used for the estimates
- `forward_DGs`, `reverse_DGs`: unit.Quantity
The forward and reverse estimates for each fraction of data
- `forward_dDGs`, `reverse_dDGs`: unit.Quantity
The forward and reverse estimate uncertainty for each
fraction of data.
"""

forward_reverse = {}
Expand All @@ -183,6 +186,140 @@ def get_forward_and_reverse_analysis(self) -> dict[str, list[dict[str, Union[npt

return forward_reverse

def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]:
    """
    Get the MBAR overlap estimates for all legs of the simulation.

    Returns
    -------
    overlap_stats : dict[str, list[dict[str, npt.NDArray]]]
      A dictionary keyed `solvent` and `vacuum`, one entry per leg of
      the thermodynamic cycle. Each entry is a list holding the MBAR
      overlap dictionary of every repeat of that simulation type.

      The underlying MBAR dictionaries contain the following keys:
        * ``scalar``: One minus the largest nontrivial eigenvalue
        * ``eigenvalues``: The sorted (descending) eigenvalues of the
          overlap matrix
        * ``matrix``: Estimated overlap matrix of observing a sample from
          state i in state j
    """
    # For each leg, read the overlap output from the first unit result
    # stored for every repeat.
    return {
        leg: [
            repeat[0].outputs['unit_mbar_overlap']
            for repeat in self.data[leg].values()
        ]
        for leg in ('solvent', 'vacuum')
    }

def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]:
    """
    Get the replica exchange transition statistics for all
    legs of the simulation.

    Note
    ----
    This is currently only available in cases where a replica exchange
    simulation was run.

    Returns
    -------
    repex_stats : dict[str, list[dict[str, npt.NDArray]]]
      A dictionary keyed `solvent` and `vacuum`, one entry per leg of
      the thermodynamic cycle. Each entry is a list holding the replica
      transition statistics dictionary of every repeat of that
      simulation type.

      The replica transition statistics dictionaries contain the following:
        * ``eigenvalues``: The sorted (descending) eigenvalues of the
          lambda state transition matrix
        * ``matrix``: The transition matrix estimate of a replica switching
          from state i to state j.

    Raises
    ------
    ValueError
      If any lookup performed while gathering the statistics raises a
      ``KeyError`` (e.g. the ``replica_exchange_statistics`` output is
      absent).
    """
    try:
        # Same traversal for both legs: first unit result of each repeat.
        return {
            leg: [
                repeat[0].outputs['replica_exchange_statistics']
                for repeat in self.data[leg].values()
            ]
            for leg in ('solvent', 'vacuum')
        }
    except KeyError:
        raise ValueError(
            "Replica exchange statistics were not found, "
            "did you run a repex calculation?"
        )

def get_replica_states(self) -> dict[str, list[npt.NDArray]]:
    """
    Get the timeseries of replica states for all simulation legs.

    Returns
    -------
    replica_states : dict[str, list[npt.NDArray]]
      Dictionary keyed `solvent` and `vacuum` for each leg of
      the thermodynamic cycle, with lists of replica states
      timeseries for each repeat of that simulation type.
    """
    replica_states = {}

    for key in ['solvent', 'vacuum']:
        # Fixes: the original assigned into the undefined name
        # ``replicate_states`` and read ``.output`` rather than the
        # ``.outputs`` attribute used by the sibling accessors.
        replica_states[key] = [
            pus[0].outputs['replica_states']
            for pus in self.data[key].values()
        ]

    return replica_states

def equilibration_iterations(self) -> dict[str, list[float]]:
    """
    Get the number of equilibration iterations for each simulation.

    Returns
    -------
    equilibration_lengths : dict[str, list[float]]
      Dictionary keyed `solvent` and `vacuum` for each leg
      of the thermodynamic cycle, with lists containing the
      number of equilibration iterations for each repeat
      of that simulation type.
    """
    equilibration_lengths = {}

    for key in ['solvent', 'vacuum']:
        # Read from ``.outputs`` (not ``.output``) so this accessor
        # matches the attribute used by the other result getters.
        equilibration_lengths[key] = [
            pus[0].outputs['equilibration_iterations']
            for pus in self.data[key].values()
        ]

    return equilibration_lengths

def production_iterations(self) -> dict[str, list[float]]:
    """
    Get the number of production iterations for each simulation.

    Returns the number of uncorrelated production samples for each
    repeat of the calculation.

    Returns
    -------
    production_lengths : dict[str, list[float]]
      Dictionary keyed `solvent` and `vacuum` for each leg of the
      thermodynamic cycle, with lists with the number
      of production iterations for each repeat of that simulation
      type.
    """
    production_lengths = {}

    for key in ['solvent', 'vacuum']:
        # Fixes: the original stored results in the undefined name
        # ``equilibration_lengths``, read ``.output`` instead of
        # ``.outputs``, and had a stray character after the docstring.
        production_lengths[key] = [
            pus[0].outputs['production_iterations']
            for pus in self.data[key].values()
        ]

    return production_lengths


class AbsoluteSolvationProtocol(gufe.Protocol):
"""
Expand Down

0 comments on commit c50bf90

Please sign in to comment.