Skip to content

Commit

Permalink
[dG] Fix extracting solvent when v-sites present on solute (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 26, 2024
1 parent 3dd387b commit 97d6217
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
9 changes: 3 additions & 6 deletions smee/mm/_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,16 @@ def _extract_pure_solvent(
device = force_field.potentials[0].parameters.device
dtype = force_field.potentials[0].parameters.dtype

system, beta, pressure, u_kn, n_k, xyz, box = _load_samples(
system, _, _, _, _, xyz, box = _load_samples(
solute, solvent, output_dir, device, dtype, coord_state_idx=-1
)

if len(system.topologies) != 2 or system.n_copies[0] != 1:
raise NotImplementedError("only single solute systems are supported.")

n_solute_atoms = system.topologies[0].n_atoms
xyz = xyz[:, n_solute_atoms:, :]
xyz = xyz[:, solute.n_particles :, :]

system = smee.TensorSystem(
[system.topologies[1]], [system.n_copies[1]], is_periodic=True
)
system = smee.TensorSystem([solvent], [system.n_copies[1]], is_periodic=True)
energy = _compute_energy(system, force_field, xyz, box)

return xyz, box, energy
Expand Down
41 changes: 41 additions & 0 deletions smee/tests/mm/test_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,35 @@

import openff.interchange
import openff.toolkit
import openff.units
import openmm.unit
import pytest
import torch

import smee.converters
import smee.mm
import smee.mm._fe


def load_systems(solute: str, solvent: str):
ff_off = openff.toolkit.ForceField("openff-2.0.0.offxml")

v_site_handler = ff_off.get_parameter_handler("VirtualSites")
v_site_handler.add_parameter(
{
"type": "DivalentLonePair",
"match": "once",
"smirks": "[*:2][#7:1][*:3]",
"distance": 0.4 * openff.units.unit.angstrom,
"epsilon": 0.0 * openff.units.unit.kilojoule_per_mole,
"sigma": 0.1 * openff.units.unit.nanometer,
"outOfPlaneAngle": 0.0 * openff.units.unit.degree,
"charge_increment1": 0.0 * openff.units.unit.elementary_charge,
"charge_increment2": 0.0 * openff.units.unit.elementary_charge,
"charge_increment3": 0.0 * openff.units.unit.elementary_charge,
}
)

solute_inter = openff.interchange.Interchange.from_smirnoff(
ff_off,
openff.toolkit.Molecule.from_smiles(solute).to_topology(),
Expand All @@ -29,6 +48,28 @@ def load_systems(solute: str, solvent: str):
return top_solute, top_solvent, ff


def test_extract_pure_solvent(tmp_cwd, mocker):
top_solute, top_solvent, ff = load_systems("c1ccncc1", "O")

system = smee.TensorSystem([top_solute, top_solvent], [1, 10], True)
xyz, box = smee.mm.generate_system_coords(system, ff)

xyz = torch.tensor(xyz.value_in_unit(openmm.unit.angstrom)).unsqueeze(0)
box = torch.tensor(box.value_in_unit(openmm.unit.angstrom)).unsqueeze(0) * 10.0

mocker.patch(
"smee.mm._fe._load_samples",
return_value=(system, None, None, None, None, xyz, box),
)

xyz_solv, _, _ = smee.mm._fe._extract_pure_solvent(
top_solute, top_solvent, ff, tmp_cwd
)

assert xyz_solv.shape == (1, 30, 3)
assert torch.allclose(xyz_solv, xyz[:, 12:, :])


@pytest.mark.fe
def test_fe_ops(tmp_cwd):
# taken from a run on commit ec3d272b466f761ed838e16a5ba7b97ceadc463b
Expand Down

0 comments on commit 97d6217

Please sign in to comment.