From 97d62179f325a73a3ce7c8e59baace16f7246af8 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Tue, 26 Nov 2024 13:00:09 +0000 Subject: [PATCH] [dG] Fix extracting solvent when v-sites present on solute (#126) --- smee/mm/_fe.py | 9 +++------ smee/tests/mm/test_fe.py | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/smee/mm/_fe.py b/smee/mm/_fe.py index 7a2d196..d8a5d94 100644 --- a/smee/mm/_fe.py +++ b/smee/mm/_fe.py @@ -47,19 +47,16 @@ def _extract_pure_solvent( device = force_field.potentials[0].parameters.device dtype = force_field.potentials[0].parameters.dtype - system, beta, pressure, u_kn, n_k, xyz, box = _load_samples( + system, _, _, _, _, xyz, box = _load_samples( solute, solvent, output_dir, device, dtype, coord_state_idx=-1 ) if len(system.topologies) != 2 or system.n_copies[0] != 1: raise NotImplementedError("only single solute systems are supported.") - n_solute_atoms = system.topologies[0].n_atoms - xyz = xyz[:, n_solute_atoms:, :] + xyz = xyz[:, solute.n_particles :, :] - system = smee.TensorSystem( - [system.topologies[1]], [system.n_copies[1]], is_periodic=True - ) + system = smee.TensorSystem([solvent], [system.n_copies[1]], is_periodic=True) energy = _compute_energy(system, force_field, xyz, box) return xyz, box, energy diff --git a/smee/tests/mm/test_fe.py b/smee/tests/mm/test_fe.py index 73dc0f4..0ea3207 100644 --- a/smee/tests/mm/test_fe.py +++ b/smee/tests/mm/test_fe.py @@ -2,16 +2,35 @@ import openff.interchange import openff.toolkit +import openff.units +import openmm.unit import pytest import torch import smee.converters import smee.mm +import smee.mm._fe def load_systems(solute: str, solvent: str): ff_off = openff.toolkit.ForceField("openff-2.0.0.offxml") + v_site_handler = ff_off.get_parameter_handler("VirtualSites") + v_site_handler.add_parameter( + { + "type": "DivalentLonePair", + "match": "once", + "smirks": "[*:2][#7:1][*:3]", + "distance": 0.4 * openff.units.unit.angstrom, + "epsilon": 0.0 * openff.units.unit.kilojoule_per_mole, + "sigma": 0.1 * openff.units.unit.nanometer, + "outOfPlaneAngle": 0.0 * openff.units.unit.degree, + "charge_increment1": 0.0 * openff.units.unit.elementary_charge, + "charge_increment2": 0.0 * openff.units.unit.elementary_charge, + "charge_increment3": 0.0 * openff.units.unit.elementary_charge, + } + ) + solute_inter = openff.interchange.Interchange.from_smirnoff( ff_off, openff.toolkit.Molecule.from_smiles(solute).to_topology(), @@ -29,6 +48,28 @@ def load_systems(solute: str, solvent: str): return top_solute, top_solvent, ff +def test_extract_pure_solvent(tmp_cwd, mocker): + top_solute, top_solvent, ff = load_systems("c1ccncc1", "O") + + system = smee.TensorSystem([top_solute, top_solvent], [1, 10], True) + xyz, box = smee.mm.generate_system_coords(system, ff) + + xyz = torch.tensor(xyz.value_in_unit(openmm.unit.angstrom)).unsqueeze(0) + box = torch.tensor(box.value_in_unit(openmm.unit.angstrom)).unsqueeze(0) * 10.0 + + mocker.patch( + "smee.mm._fe._load_samples", + return_value=(system, None, None, None, None, xyz, box), + ) + + xyz_solv, _, _ = smee.mm._fe._extract_pure_solvent( + top_solute, top_solvent, ff, tmp_cwd + ) + + assert xyz_solv.shape == (1, 30, 3) + assert torch.allclose(xyz_solv, xyz[:, 12:, :]) + + @pytest.mark.fe def test_fe_ops(tmp_cwd): # taken from a run on commit ec3d272b466f761ed838e16a5ba7b97ceadc463b