Skip to content

Commit

Permalink
Support generating data to compute Gtransfer (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 10, 2024
1 parent a1bcf93 commit 5d1a8da
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 50 deletions.
127 changes: 78 additions & 49 deletions smee/mm/_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,50 @@ def _extract_pure_solvent(
return system, xyz, box, beta, pressure, energy


def _to_solvent_dict(
solvent: smee.TensorTopology, n_solvent: int
) -> dict[str, int] | None:
if solvent is None:
return None

smiles = openff.toolkit.Molecule.from_rdkit(
smee.mm._utils.topology_to_rdkit(solvent),
allow_undefined_stereo=True,
).to_smiles(mapped=True)

return {smiles: n_solvent}


def generate_dg_solv_data(
solute: smee.TensorTopology,
solvent: smee.TensorTopology,
solvent_a: smee.TensorTopology | None,
solvent_b: smee.TensorTopology | None,
force_field: smee.TensorForceField,
temperature: openmm.unit.Quantity = 298.15 * openmm.unit.kelvin,
pressure: openmm.unit.Quantity = 1.0 * openmm.unit.atmosphere,
vacuum_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None,
solvent_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None,
n_solvent: int = 216,
solvent_a_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None,
solvent_b_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None,
n_solvent_a: int = 216,
n_solvent_b: int = 216,
output_dir: pathlib.Path | None = None,
):
"""Run a solvation free energy calculation using ``absolv``, and saves the output
such that a differentiable free energy can be computed.
The free energy will correspond to the free energy of transferring a solute from
solvent A to solvent B.
Args:
solute: The solute topology.
solvent: The solvent topology.
solvent_a: The topology of solvent A, or ``None`` if solvent A is vacuum.
solvent_b: The topology of solvent B, or ``None`` if solvent B is vacuum.
force_field: The force field to parameterize the system with.
temperature: The temperature to simulate at.
pressure: The pressure to simulate at.
vacuum_protocol: The protocol to use for the vacuum phase.
solvent_protocol: The protocol to use for the solvent phase.
n_solvent: The number of solvent molecules to use.
solvent_a_protocol: The protocol to use to decouple the solute in solvent A.
solvent_b_protocol: The protocol to use to decouple the solute in solvent B.
n_solvent_a: The number of solvent A molecules to use.
n_solvent_b: The number of solvent B molecules to use.
output_dir: The directory to write the output FEP data to.
"""
import absolv.config
Expand All @@ -80,65 +101,69 @@ def generate_dg_solv_data(

output_dir = pathlib.Path.cwd() if output_dir is None else output_dir

if vacuum_protocol is None:
vacuum_protocol = absolv.config.EquilibriumProtocol(
production_protocol=absolv.config.HREMDProtocol(
n_steps_per_cycle=500,
n_cycles=2000,
integrator=femto.md.config.LangevinIntegrator(
timestep=1.0 * openmm.unit.femtosecond
),
trajectory_interval=1,
vacuum_protocol = absolv.config.EquilibriumProtocol(
production_protocol=absolv.config.HREMDProtocol(
n_steps_per_cycle=500,
n_cycles=2000,
integrator=femto.md.config.LangevinIntegrator(
timestep=1.0 * openmm.unit.femtosecond
),
lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_VACUUM,
lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_VACUUM,
)
if solvent_protocol is None:
solvent_protocol = absolv.config.EquilibriumProtocol(
production_protocol=absolv.config.HREMDProtocol(
n_steps_per_cycle=500,
n_cycles=1000,
integrator=femto.md.config.LangevinIntegrator(
timestep=4.0 * openmm.unit.femtosecond
),
trajectory_interval=1,
trajectory_enforce_pbc=True,
trajectory_interval=1,
),
lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_VACUUM,
lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_VACUUM,
)
solution_protocol = absolv.config.EquilibriumProtocol(
production_protocol=absolv.config.HREMDProtocol(
n_steps_per_cycle=500,
n_cycles=1000,
integrator=femto.md.config.LangevinIntegrator(
timestep=4.0 * openmm.unit.femtosecond
),
lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_SOLVENT,
lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_SOLVENT,
)
trajectory_interval=1,
trajectory_enforce_pbc=True,
),
lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_SOLVENT,
lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_SOLVENT,
)

config = absolv.config.Config(
temperature=temperature,
pressure=pressure,
alchemical_protocol_a=vacuum_protocol,
alchemical_protocol_b=solvent_protocol,
alchemical_protocol_a=solvent_a_protocol
if solvent_a_protocol is not None
else (vacuum_protocol if solvent_a is None else solution_protocol),
alchemical_protocol_b=solvent_b_protocol
if solvent_b_protocol is not None
else (vacuum_protocol if solvent_b is None else solution_protocol),
)

solute_mol = openff.toolkit.Molecule.from_rdkit(
smee.mm._utils.topology_to_rdkit(solute),
allow_undefined_stereo=True,
)
solvent_mol = openff.toolkit.Molecule.from_rdkit(
smee.mm._utils.topology_to_rdkit(solvent),
allow_undefined_stereo=True,
)

system_config = absolv.config.System(
solutes={solute_mol.to_smiles(mapped=True): 1},
solvent_a=None,
solvent_b={solvent_mol.to_smiles(mapped=True): n_solvent},
solvent_a=_to_solvent_dict(solvent_a, n_solvent_a),
solvent_b=_to_solvent_dict(solvent_b, n_solvent_b),
)

topologies = {
"solvent-a": smee.TensorSystem([solute], [1], is_periodic=False),
"solvent-b": smee.TensorSystem(
[solute, solvent], [1, n_solvent], is_periodic=True
),
"solvent-a": smee.TensorSystem([solute], [1], is_periodic=False)
if solvent_a is None
else smee.TensorSystem([solute, solvent_a], [1, n_solvent_a], is_periodic=True),
"solvent-b": smee.TensorSystem([solute], [1], is_periodic=False)
if solvent_b is None
else smee.TensorSystem([solute, solvent_b], [1, n_solvent_b], is_periodic=True),
}
pressures = {
"solvent-a": None,
"solvent-b": pressure.value_in_unit(openmm.unit.atmosphere),
"solvent-a": None
if solvent_a is None
else pressure.value_in_unit(openmm.unit.atmosphere),
"solvent-b": None
if solvent_b is None
else pressure.value_in_unit(openmm.unit.atmosphere),
}

for phase, topology in topologies.items():
Expand Down Expand Up @@ -173,8 +198,12 @@ def _parameterize(
config, prepared_system_a, prepared_system_b, "CUDA", output_dir, parallel=True
)

solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b")
torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt")
if solvent_a is not None:
solvent_a_output = _extract_pure_solvent(force_field, output_dir / "solvent-a")
torch.save(solvent_a_output, output_dir / "solvent-a" / "pure.pt")
if solvent_b is not None:
solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b")
torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt")

return result

Expand Down
4 changes: 3 additions & 1 deletion smee/tests/mm/test_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def test_fe_ops(tmp_cwd):
output_dir = pathlib.Path("CCO")
output_dir.mkdir(parents=True, exist_ok=True)

smee.mm.generate_dg_solv_data(top_solute, top_solvent, ff, output_dir=output_dir)
smee.mm.generate_dg_solv_data(
top_solute, None, top_solvent, ff, output_dir=output_dir
)

params = ff.potentials_by_type["Electrostatics"].parameters
params.requires_grad_(True)
Expand Down

0 comments on commit 5d1a8da

Please sign in to comment.