Skip to content

Commit

Permalink
Merge pull request #664 from choderalab/replica-improvements
Browse files Browse the repository at this point in the history
Replica improvements
  • Loading branch information
hannahbrucemacdonald authored Apr 13, 2020
2 parents eb94231 + e86c345 commit 6fd9650
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 19 deletions.
47 changes: 47 additions & 0 deletions perses/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,50 @@ def get_t0(filename):
if t0 == 0:
_logger.warning(f"t0 for this file is zero, which means that stage 2 was not reached within the simulation")
return t0


def plot_replica_mixing(ncfile_name, title='', filename='replicas.png'):
    """
    Plot the path of each replica through the states, with the marginal
    distribution of visited states shown alongside.

    One row of subplots is drawn per replica: the left panel shows the state
    index occupied at every iteration, the right panel a horizontal histogram
    of how often each state was visited. The figure is written to `filename`.

    Parameters
    ----------
    ncfile_name : str
        path to nc file to analyse
    title : str, default=''
        Title to add to plot
    filename : str, default='replicas.png'
        path where to save the output plot
    """
    import numpy as np
    import matplotlib.pyplot as plt

    ncfile = open_netcdf(ncfile_name)
    try:
        # Read the whole table once instead of hitting the file per replica.
        states = ncfile.variables['states'][:]
    finally:
        # NOTE(review): assumes open_netcdf returns a netCDF4-style dataset
        # that exposes close() -- confirm against its definition.
        ncfile.close()

    # The second axis is indexed per replica in the loop below.
    n_iter, n_states = states.shape

    # plt.get_cmap is available in both old and new matplotlib releases,
    # unlike plt.cm.get_cmap which has been removed from recent ones.
    cmap = plt.get_cmap('gist_rainbow')
    colours = [cmap(i) for i in np.linspace(0., 1., n_states)]

    # Centre one histogram bin on every integer state index so the marginal
    # counts line up with the y axis shared with the trace panel. Passing a
    # bare bin count would instead spread the bins over each replica's own
    # visited range.
    n_bins = n_states
    if states.size:
        n_bins = max(n_bins, int(states.max()) + 1)
    bin_edges = np.arange(n_bins + 1) - 0.5

    # squeeze=False keeps `axes` two-dimensional even for a single replica.
    fig, axes = plt.subplots(nrows=n_states, ncols=2, sharex='col', sharey=True,
                             figsize=(15, 2 * n_states), squeeze=False,
                             gridspec_kw={'width_ratios': [5, 1]})

    for rep, colour in enumerate(colours):
        ax, hist_plot = axes[rep]
        y = states[:, rep]

        ax.plot(y, marker='.', linewidth=0, markersize=2, color=colour)
        ax.set_xlim(-1, n_iter + 1)
        ax.set_ylabel('State')

        hist_plot.hist(y, bins=bin_edges, orientation='horizontal',
                       histtype='step', color=colour, linewidth=3)
        hist_plot.yaxis.set_label_position("right")
        hist_plot.set_ylabel(f'Replica {rep}', rotation=270, labelpad=10)

    # x labels only on the bottom row; the columns share their x axes
    axes[-1, 0].set_xlabel('Iteration')
    axes[-1, 1].set_xlabel('State count')

    if title:
        # A figure-level title; plt.title would only label the last subplot.
        fig.suptitle(title)
        # Reserve a strip at the top so the title does not overlap the axes.
        fig.tight_layout(rect=(0, 0, 1, 0.97))
    else:
        fig.tight_layout()
    fig.savefig(filename)

25 changes: 24 additions & 1 deletion perses/annihilation/lambda_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ def __init__(self, functions='default'):
All protocols must begin and end at 0 and 1 respectively. Any energy term not defined
in `functions` dict will be set to the function in `default_functions`
Pre-coded options:
default : ele and LJ terms of the old system are turned off between 0.0 -> 0.5
ele and LJ terms of the new system are turned on between 0.5 -> 1.0
core terms treated linearly
quarters : 0.25 of the protocol is used in turn to individually change the
(a) off old ele, (b) off old sterics, (c) on new sterics (d) on new ele
core terms treated linearly
namd : follows the protocol outlined here: https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00362#
Jiang, Wei, Christophe Chipot, and Benoît Roux. "Computing Relative Binding Affinity of Ligands
to Receptor: An Effective Hybrid Single-Dual-Topology Free-Energy Perturbation Approach in NAMD."
Journal of chemical information and modeling 59.9 (2019): 3794-3802.
ele-scaled : all terms are treated as in default, except for the old and new ele
these are scaled with lambda^0.5, so as to be linear in energy, rather than lambda
Parameters
----------
type : str or dict, default='default'
Expand Down Expand Up @@ -102,6 +119,12 @@ def __init__(self, functions='default'):
lambda x: x,
'lambda_torsions':
lambda x: x}
elif self.type == 'ele-scaled':
self.functions = {'lambda_electrostatics_insert':
lambda x: 0.0 if x < 0.5 else ((2*(x-0.5))**0.5),
'lambda_electrostatics_delete':
lambda x: (2*x)**2 if x < 0.5 else 1.0
}
else:
_logger.warning(f"""LambdaProtocol type : {self.type} not
recognised. Allowed values are 'default',
Expand Down Expand Up @@ -173,7 +196,7 @@ def _check_for_naked_charges(self,n=10):
def get_functions(self):
return self.functions

def plot_fucntions(self,n=50):
def plot_functions(self,n=50):
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,5))

Expand Down
5 changes: 2 additions & 3 deletions perses/app/relative_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def __init__(self, ligand_input, old_ligand_index, new_ligand_index, forcefield_
else:
barostat = None
_logger.info(f"omitted MonteCarloBarostat because pressure was specified but system was not periodic")
else:
barostat = None
_logger.info(f"omitted MonteCarloBarostat because pressure was not specified")

Expand Down Expand Up @@ -263,9 +262,9 @@ def __init__(self, ligand_input, old_ligand_index, new_ligand_index, forcefield_

# Create SystemGenerator
from openmmforcefields.generators import SystemGenerator
forcefield_kwargs = {'removeCMMotion': False, 'ewaldErrorTolerance': self._pme_tol, 'nonbondedMethod': self._nonbonded_method,'constraints' : app.HBonds, 'hydrogenMass' : self._hmass}
forcefield_kwargs = {'removeCMMotion': False, 'ewaldErrorTolerance': self._pme_tol, 'constraints' : app.HBonds, 'hydrogenMass' : self._hmass}
self._system_generator = SystemGenerator(forcefields=forcefield_files, barostat=barostat, forcefield_kwargs=forcefield_kwargs,
small_molecule_forcefield=small_molecule_forcefield, molecules=molecules, cache=small_molecule_parameters_cache)
small_molecule_forcefield=small_molecule_forcefield, molecules=molecules, cache=small_molecule_parameters_cache, periodic_forcefield_kwargs = {'nonbondedMethod': self._nonbonded_method})
_logger.info("successfully created SystemGenerator to create ligand systems")

_logger.info(f"executing SmallMoleculeSetProposalEngine...")
Expand Down
36 changes: 29 additions & 7 deletions perses/app/setup_relative_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,10 @@ def run_setup(setup_options):
_logger.info(f"\tno nonequilibrium detected.")
n_states = setup_options['n_states']
_logger.info(f"\tn_states: {n_states}")
n_replicas = setup_options['n_replicas']
if 'n_replicas' not in setup_options:
n_replicas = n_states
else:
n_replicas = setup_options['n_replicas']
_logger.info(f"\tn_replicas: {n_replicas}")
checkpoint_interval = setup_options['checkpoint_interval']
_logger.info(f"\tcheckpoint_interval: {checkpoint_interval}")
Expand All @@ -495,12 +498,17 @@ def run_setup(setup_options):
_forward_added_valence_energy = top_prop['%s_added_valence_energy' % phase]
_reverse_subtracted_valence_energy = top_prop['%s_subtracted_valence_energy' % phase]

zero_state_error, one_state_error = validate_endstate_energies(_top_prop, _htf, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD)
xml_directory = f'{setup_options["trajectory_directory"]}/xml/'
if not os.path.exists(xml_directory):
os.makedirs(xml_directory)

zero_state_error, one_state_error = validate_endstate_energies(_top_prop, _htf, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD, trajectory_directory=f'{xml_directory}{phase}')
_logger.info(f"\t\terror in zero state: {zero_state_error}")
_logger.info(f"\t\terror in one state: {one_state_error}")

# generating lambda protocol
lambda_protocol = LambdaProtocol(functions=setup_options['protocol-type'])
_logger.info(f'Using lambda protocol : {setup_options["protocol-type"]}')


if atom_selection:
Expand All @@ -521,24 +529,38 @@ def run_setup(setup_options):
endstates = True
#TODO expose more of these options in input
if setup_options['fe_type'] == 'sams':
hss[phase] = HybridSAMSSampler(mcmc_moves=mcmc.LangevinDynamicsMove(timestep=timestep,
hss[phase] = HybridSAMSSampler(mcmc_moves=mcmc.LangevinSplittingDynamicsMove(timestep=timestep,
collision_rate=1.0 / unit.picosecond,
n_steps=n_steps_per_move_application,
reassign_velocities=False,
n_restart_attempts=20),
n_restart_attempts=20,constraint_tolerance=1e-06),
hybrid_factory=htf[phase], online_analysis_interval=setup_options['offline-freq'],
online_analysis_minimum_iterations=10,flatness_criteria=setup_options['flatness-criteria'],
gamma0=setup_options['gamma0'])
hss[phase].setup(n_states=n_states, n_replicas=n_replicas, temperature=temperature,storage_file=reporter,lambda_protocol=lambda_protocol,endstates=endstates)
elif setup_options['fe_type'] == 'repex':
hss[phase] = HybridRepexSampler(mcmc_moves=mcmc.LangevinDynamicsMove(timestep=timestep,
hss[phase] = HybridRepexSampler(mcmc_moves=mcmc.LangevinSplittingDynamicsMove(timestep=timestep,
collision_rate=1.0 / unit.picosecond,
n_steps=n_steps_per_move_application,
reassign_velocities=False,
n_restart_attempts=20),
n_restart_attempts=20,constraint_tolerance=1e-06),
hybrid_factory=htf[phase],online_analysis_interval=setup_options['offline-freq'])
hss[phase].setup(n_states=n_states, temperature=temperature,storage_file=reporter,lambda_protocol=lambda_protocol,endstates=endstates)

# save the systems and the states
from simtk.openmm import XmlSerializer
from perses.tests.utils import generate_endpoint_thermodynamic_states

_logger.info('WRITING OUT XML FILES')
#old_thermodynamic_state, new_thermodynamic_state, hybrid_thermodynamic_state, _ = generate_endpoint_thermodynamic_states(htf[phase].hybrid_system, _top_prop)


from perses.utils import data
_logger.info(f'Saving the hybrid, old and new system to disk')
data.serialize(htf[phase].hybrid_system, f'{setup_options["trajectory_directory"]}/xml/{phase}-hybrid-system.gz')
data.serialize(htf[phase]._old_system, f'{setup_options["trajectory_directory"]}/xml/{phase}-old-system.gz')
data.serialize(htf[phase]._new_system, f'{setup_options["trajectory_directory"]}/xml/{phase}-new-system.gz')

return {'topology_proposals': top_prop, 'hybrid_topology_factories': htf, 'hybrid_samplers': hss}


Expand Down Expand Up @@ -629,7 +651,7 @@ def run(yaml_filename=None):
_forward_added_valence_energy = setup_dict['topology_proposals'][f"{phase}_added_valence_energy"]
_reverse_subtracted_valence_energy = setup_dict['topology_proposals'][f"{phase}_subtracted_valence_energy"]

zero_state_error, one_state_error = validate_endstate_energies(hybrid_factory._topology_proposal, hybrid_factory, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD)
zero_state_error, one_state_error = validate_endstate_energies(hybrid_factory._topology_proposal, hybrid_factory, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD, trajectory_directory=f'{setup_options["trajectory_directory"]}/xml/{phase}')
_logger.info(f"\t\terror in zero state: {zero_state_error}")
_logger.info(f"\t\terror in one state: {one_state_error}")

Expand Down
18 changes: 11 additions & 7 deletions perses/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def validate_rjmc_work_variance(top_prop, positions, geometry_method = 0, num_it

return conformers, rj_works

def validate_endstate_energies(topology_proposal, htf, added_energy, subtracted_energy, beta = 1.0/kT, ENERGY_THRESHOLD = 1e-6, platform = DEFAULT_PLATFORM):
def validate_endstate_energies(topology_proposal, htf, added_energy, subtracted_energy, beta = 1.0/kT, ENERGY_THRESHOLD = 1e-6, platform = DEFAULT_PLATFORM, trajectory_directory=None):
"""
Function to validate that the difference between the nonalchemical versus alchemical state at lambda = 0,1 is
equal to the difference in valence energy (forward and reverse).
Expand All @@ -746,6 +746,7 @@ def validate_endstate_energies(topology_proposal, htf, added_energy, subtracted_
#import openmmtools.cache as cache
#context_cache = cache.global_context_cache
from perses.dispersed.utils import configure_platform
from perses.utils import data
platform = configure_platform(platform.getName(), fallback_platform_name='Reference', precision='double')

#create copies of old/new systems and set the dispersion correction
Expand All @@ -770,14 +771,13 @@ def validate_endstate_energies(topology_proposal, htf, added_energy, subtracted_

# compute reduced energies
#for the nonalchemical systems...
attrib_list = [(nonalch_zero, old_positions, top_proposal._old_system.getDefaultPeriodicBoxVectors()),
(alch_zero, htf._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()),
(alch_one, htf._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()),
(nonalch_one, new_positions, top_proposal._new_system.getDefaultPeriodicBoxVectors())]
attrib_list = [('real-old',nonalch_zero, old_positions, top_proposal._old_system.getDefaultPeriodicBoxVectors()),
('hybrid-old',alch_zero, htf._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()),
('hybrid-new',alch_one, htf._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()),
('real-new',nonalch_one, new_positions, top_proposal._new_system.getDefaultPeriodicBoxVectors())]

rp_list = []
for (state, pos, box_vectors) in attrib_list:
#print("\t\t\t{}".format(state))
for (state_name, state, pos, box_vectors) in attrib_list:
integrator = openmm.VerletIntegrator(1.0 * unit.femtoseconds)
context = state.create_context(integrator, platform)
samplerstate = states.SamplerState(positions = pos, box_vectors = box_vectors)
Expand All @@ -789,6 +789,10 @@ def validate_endstate_energies(topology_proposal, htf, added_energy, subtracted_
print("\t\t\t{}: {}".format(name, force))
_logger.debug(f'added forces:{sum([energy for name, energy in energy_comps])}')
_logger.debug(f'rp: {rp}')
if trajectory_directory is not None:
_logger.info(f'Saving {state_name} state xml to {trajectory_directory}/{state_name}-state.gz')
state = context.getState(getPositions=True, getVelocities=True, getForces=True, getEnergy=True, getParameters=True)
data.serialize(state,f'{trajectory_directory}-{state_name}-state.gz')
del context, integrator

nonalch_zero_rp, alch_zero_rp, alch_one_rp, nonalch_one_rp = rp_list[0], rp_list[1], rp_list[2], rp_list[3]
Expand Down
24 changes: 23 additions & 1 deletion perses/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,26 @@ def load_smi(smi_file,index=None):
return smiless
else:
smiles = smiless[index]
return smiles
return smiles


def serialize(item, filename):
    """
    Serialize an OpenMM System, State, or Integrator to an XML file.

    The XML text is generated before the output file is opened, so a failure
    during serialization neither creates nor truncates `filename`.

    Parameters
    ----------
    item : System, State, or Integrator
        The thing to be serialized
    filename : str or os.PathLike
        The filename to serialize to. Names ending in 'gz' are written
        gzip-compressed; any other name is written as plain text.
    """
    import os
    from simtk.openmm import XmlSerializer

    # Accept pathlib.Path (and bytes paths) as well as plain strings.
    path = os.fsdecode(filename)

    # Serialize first: opening the file for writing truncates it immediately.
    serialized_thing = XmlSerializer.serialize(item)

    if path.endswith('gz'):
        import gzip
        with gzip.open(path, 'wb') as outfile:
            outfile.write(serialized_thing.encode('utf-8'))
    else:
        # Explicit encoding keeps the plain-text output consistent with the
        # gzip branch regardless of the platform's locale default.
        with open(path, 'w', encoding='utf-8') as outfile:
            outfile.write(serialized_thing)

0 comments on commit 6fd9650

Please sign in to comment.