Skip to content

Commit

Permalink
Store enough extra info to convert to OpenMM (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Oct 6, 2023
1 parent 62bf8e3 commit 56284ca
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 14 deletions.
70 changes: 62 additions & 8 deletions smee/ff/_ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,31 @@ class TensorTopology:
"""A tensor representation of a molecular topology that has been assigned force
field parameters."""

n_atoms: int
"""The number of atoms in the topology."""
atomic_nums: torch.Tensor
"""The atomic numbers of each atom in the topology with ``shape=(n_atoms,)``"""
formal_charges: torch.Tensor
"""The formal charge of each atom in the topology with ``shape=(n_atoms,)``"""

bond_idxs: torch.Tensor
"""The indices of the atoms involved in each bond with ``shape=(n_bonds, 2)``"""
bond_orders: torch.Tensor
"""The bond orders of each bond with ``shape=(n_bonds,)``"""

parameters: dict[str, ParameterMap]
"""The parameters that have been assigned to the topology."""
v_sites: VSiteMap | None = None
"""The v-sites that have been assigned to the topology."""

@property
def n_atoms(self) -> int:
"""The number of atoms in the topology."""
return len(self.atomic_nums)

@property
def n_bonds(self) -> int:
"""The number of bonds in the topology."""
return len(self.bond_idxs)


@dataclasses.dataclass
class TensorPotential:
Expand Down Expand Up @@ -361,7 +378,7 @@ def convert_handlers(
handlers: list[openff.interchange.smirnoff._base.SMIRNOFFCollection],
topologies: list[openff.toolkit.Topology],
v_site_maps: list[VSiteMap | None] | None = None,
):
) -> tuple[TensorPotential, list[ParameterMap]]:
"""Convert a set of SMIRNOFF parameter handlers into a set of tensor potentials.
Args:
Expand Down Expand Up @@ -420,6 +437,43 @@ def convert_handlers(
return converter(handlers, **converter_kwargs)


def _convert_topology(
topology: openff.toolkit.Topology,
parameters: dict[str, ParameterMap],
v_sites: VSiteMap | None,
) -> TensorTopology:
"""Convert an OpenFF topology into a tensor topology.
Args:
topology: The topology to convert.
parameters: The parameters assigned to the topology.
v_sites: The v-sites assigned to the topology.
Returns:
The converted topology.
"""

atomic_nums = torch.tensor([atom.atomic_number for atom in topology.atoms])

formal_charges = torch.tensor(
[atom.formal_charge.m_as(openff.units.unit.e) for atom in topology.atoms]
)

bond_idxs = torch.tensor(
[[bond.atom1_index, bond.atom2_index] for bond in topology.bonds]
)
bond_orders = torch.tensor([bond.bond_order for bond in topology.bonds])

return TensorTopology(
atomic_nums=atomic_nums,
formal_charges=formal_charges,
bond_idxs=bond_idxs,
bond_orders=bond_orders,
parameters=parameters,
v_sites=v_sites,
)


def convert_interchange(
interchange: openff.interchange.Interchange | list[openff.interchange.Interchange],
) -> tuple[TensorForceField, list[TensorTopology]]:
Expand Down Expand Up @@ -498,15 +552,15 @@ def convert_interchange(
parameter_maps_by_handler[handler_type] = parameter_map

tensor_topologies = [
TensorTopology(
n_atoms=topologies[i].n_atoms,
parameters={
_convert_topology(
topology,
{
potential.type: parameter_maps_by_handler[potential.type][i]
for potential in potentials
},
v_sites=v_site_maps[i],
v_site_maps[i],
)
for i in range(len(topologies))
for i, topology in enumerate(topologies)
]

tensor_force_field = TensorForceField(potentials, v_sites)
Expand Down
17 changes: 13 additions & 4 deletions smee/ff/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def convert_nonbonded_handlers(
v_site_maps: The virtual site maps associated with each handler.
parameter_cols: The ordering of the parameter array columns.
attribute_cols: The handler attributes to include in the potential *in addition*
to the intr-amolecular scaling factors.
to the intra-molecular scaling factors.
Returns:
The potential containing tensors of the parameter values, and a list of
Expand Down Expand Up @@ -115,7 +115,8 @@ def convert_nonbonded_handlers(
exclusion_to_scale = smee.utils.find_exclusions(topology, v_site_map)
exclusions = torch.tensor([*exclusion_to_scale])
exclusion_scale_idxs = torch.tensor(
[[attribute_to_idx[scale]] for scale in exclusion_to_scale.values()]
[[attribute_to_idx[scale]] for scale in exclusion_to_scale.values()],
dtype=torch.int64,
)

parameter_map = smee.ff.NonbondedParameterMap(
Expand All @@ -137,6 +138,8 @@ def convert_nonbonded_handlers(
"scale_13": _UNITLESS,
"scale_14": _UNITLESS,
"scale_15": _UNITLESS,
"cutoff": _ANGSTROM,
"switch_width": _ANGSTROM,
},
)
def convert_vdw(
Expand All @@ -152,7 +155,12 @@ def convert_vdw(
raise NotImplementedError("only Lorentz-Berthelot mixing rules are supported.")

return convert_nonbonded_handlers(
handlers, "vdW", topologies, v_site_maps, ("epsilon", "sigma")
handlers,
"vdW",
topologies,
v_site_maps,
("epsilon", "sigma"),
("cutoff", "switch_width"),
)


Expand Down Expand Up @@ -204,6 +212,7 @@ def _make_v_site_electrostatics_compatible(handlers: list[_ElectrostaticParamete
"scale_13": _UNITLESS,
"scale_14": _UNITLESS,
"scale_15": _UNITLESS,
"cutoff": _ANGSTROM,
},
)
def convert_electrostatics(
Expand All @@ -215,5 +224,5 @@ def convert_electrostatics(
_make_v_site_electrostatics_compatible(handlers)

return convert_nonbonded_handlers(
handlers, "Electrostatics", topologies, v_site_maps, ("charge",)
handlers, "Electrostatics", topologies, v_site_maps, ("charge",), ("cutoff",)
)
2 changes: 1 addition & 1 deletion smee/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def ethanol_interchange(ethanol, default_force_field) -> openff.interchange.Inte

@pytest.fixture(scope="module")
def formaldehyde() -> openff.toolkit.Molecule:
"""Returns an OpenFF formaldehyde molecule with a fixed atom order.."""
"""Returns an OpenFF formaldehyde molecule with a fixed atom order."""

return openff.toolkit.Molecule.from_mapped_smiles("[H:3][C:1](=[O:2])[H:4]")

Expand Down
30 changes: 30 additions & 0 deletions smee/tests/ff/test_ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_CONVERTERS,
_DEFAULT_UNITS,
VSiteMap,
_convert_topology,
convert_handlers,
convert_interchange,
parameter_converter,
Expand Down Expand Up @@ -55,6 +56,35 @@ def test_convert_handler(ethanol, ethanol_interchange, mocker):
assert result == mock_result


def test_convert_topology(formaldehyde, mocker):
parameters = mocker.MagicMock()
v_sites = VSiteMap([], {}, torch.tensor([]))

topology = _convert_topology(formaldehyde, parameters, v_sites)

assert topology.n_atoms == 4
assert topology.n_bonds == 3

expected_atomic_nums = torch.tensor([6, 8, 1, 1])
expected_formal_charges = torch.tensor([0, 0, 0, 0])

expected_bond_idxs = torch.tensor([[0, 1], [0, 2], [0, 3]])
expected_bond_orders = torch.tensor([2, 1, 1])

assert topology.atomic_nums.shape == expected_atomic_nums.shape
assert torch.allclose(topology.atomic_nums, expected_atomic_nums)
assert topology.formal_charges.shape == expected_formal_charges.shape
assert torch.allclose(topology.formal_charges, expected_formal_charges)

assert topology.bond_idxs.shape == expected_bond_idxs.shape
assert torch.allclose(topology.bond_idxs, expected_bond_idxs)
assert topology.bond_orders.shape == expected_bond_orders.shape
assert torch.allclose(topology.bond_orders, expected_bond_orders)

assert topology.parameters == parameters
assert topology.v_sites == v_sites


def test_convert_interchange():
force_field = openff.toolkit.ForceField()
force_field.get_parameter_handler("Electrostatics")
Expand Down
5 changes: 4 additions & 1 deletion smee/tests/ff/test_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ def test_convert_electrostatics_am1bcc(ethanol, ethanol_interchange):
assert potential.type == "Electrostatics"
assert potential.fn == "coul"

expected_attributes = torch.tensor([0.0, 0.0, 5.0 / 6.0, 1.0], dtype=torch.float64)
expected_attributes = torch.tensor(
[0.0, 0.0, 5.0 / 6.0, 1.0, 9.0], dtype=torch.float64
)
assert torch.allclose(potential.attributes, expected_attributes)
assert potential.attribute_cols == (
"scale_12",
"scale_13",
"scale_14",
"scale_15",
"cutoff",
)

assert potential.parameter_cols == ("charge",)
Expand Down

0 comments on commit 56284ca

Please sign in to comment.