Skip to content

Commit

Permalink
Add support for advanced OpenFF handler converters (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Jul 17, 2024
1 parent dafbe1c commit 88347f5
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 33 deletions.
2 changes: 2 additions & 0 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ dependencies:
- pydantic
- nnpops

- networkx

# Optional packages

### MM simulations
Expand Down
133 changes: 111 additions & 22 deletions smee/converters/openff/_openff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib
import inspect
import typing

import networkx
import openff.interchange.components.potentials
import openff.interchange.models
import openff.interchange.smirnoff
Expand All @@ -11,8 +13,17 @@
import smee
import smee.geometry

_CONVERTERS = {}
_DEFAULT_UNITS = {}

class _Converter(typing.NamedTuple):
    """A converter registered for a single type of parameter handler.

    Instances are stored in ``_CONVERTERS``, keyed by the handler type.
    """

    fn: typing.Callable
    """The function that will convert the parameters of a handler into tensors."""
    units: dict[str, openff.units.Unit]
    """The default units of each parameter in the handler."""
    depends_on: list[str] | None
    """The names of other converters that this converter should be run after."""


_CONVERTERS: dict[str, _Converter] = {}

_ANGSTROM = openff.units.unit.angstrom
_RADIANS = openff.units.unit.radians
Expand All @@ -25,22 +36,27 @@


def smirnoff_parameter_converter(
type_: str, default_units: dict[str, openff.units.Unit]
type_: str,
default_units: dict[str, openff.units.Unit],
depends_on: list[str] | None = None,
):
"""A decorator used to flag a function as being able to convert a parameter handler's
parameters into tensors.
Args:
type_: The type of parameter handler that the decorated function can convert.
default_units: The default units of each parameter in the handler.
depends_on: The names of other handlers that this handler depends on. When set,
the convert function should additionally take in a list of the already
converted potentials and return a new list of potentials that should either
include or replace the original potentials.
"""

def parameter_converter_inner(func):
if type_ in _CONVERTERS:
raise KeyError(f"A {type_} converter is already registered.")

_CONVERTERS[type_] = func
_DEFAULT_UNITS[type_] = default_units
_CONVERTERS[type_] = _Converter(func, default_units, depends_on)

return func

Expand Down Expand Up @@ -82,7 +98,7 @@ def _handlers_to_potential(
_get_value(
parameters_by_key[parameter_key],
column,
_DEFAULT_UNITS[handler_type],
_CONVERTERS[handler_type].units,
)
for column in parameter_cols
]
Expand All @@ -100,7 +116,7 @@ def _handlers_to_potential(
for handler in handlers
}
attributes_by_column = {
k: v.m_as(_DEFAULT_UNITS[handler_type][k])
k: v.m_as(_CONVERTERS[handler_type].units[k])
for k, v in attributes_by_column.items()
}
attributes = torch.tensor(
Expand All @@ -115,15 +131,15 @@ def _handlers_to_potential(
parameter_keys=parameter_keys,
parameter_cols=parameter_cols,
parameter_units=tuple(
_DEFAULT_UNITS[handler_type][column] for column in parameter_cols
_CONVERTERS[handler_type].units[column] for column in parameter_cols
),
attributes=attributes,
attribute_cols=attribute_cols,
attribute_units=(
None
if attribute_cols is None
else tuple(
_DEFAULT_UNITS[handler_type][column] for column in attribute_cols
_CONVERTERS[handler_type].units[column] for column in attribute_cols
)
),
)
Expand Down Expand Up @@ -250,21 +266,25 @@ def convert_handlers(
handlers: list[openff.interchange.smirnoff.SMIRNOFFCollection],
topologies: list[openff.toolkit.Topology],
v_site_maps: list[smee.VSiteMap | None] | None = None,
) -> tuple[smee.TensorPotential, list[smee.ParameterMap]]:
potentials: (
list[tuple[smee.TensorPotential, list[smee.ParameterMap]]] | None
) = None,
) -> list[tuple[smee.TensorPotential, list[smee.ParameterMap]]]:
"""Convert a set of SMIRNOFF parameter handlers into a set of tensor potentials.
Args:
handlers: The SMIRNOFF parameter handler collections for a set of interchange
objects to convert.
topologies: The topologies associated with each interchange object.
v_site_maps: The v-site maps associated with each interchange object.
potentials: Already converted parameter handlers that may be required as
dependencies.
Returns:
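A list of tuples, one per potential, each containing the potential holding the
parameter values and a list of maps (one per topology) between molecule elements
(e.g. bond indices) and parameter indices. Already converted ``potentials`` that
were not used as dependencies or replaced are included in the list.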
Examples:
>>> from openff.toolkit import ForceField, Molecule
Expand All @@ -283,21 +303,20 @@ def convert_handlers(
>>>
>>> vdw_potential, applied_vdw_parameters = convert_handlers(interchanges)
"""
importlib.import_module("smee.converters.openff.nonbonded")
importlib.import_module("smee.converters.openff.valence")

handler_types = {handler.type for handler in handlers}
assert len(handler_types) == 1, "multiple handler types found"
handler_type = next(iter(handler_types))

assert len(handlers) == len(topologies), "mismatched number of topologies"

importlib.import_module("smee.converters.openff.nonbonded")
importlib.import_module("smee.converters.openff.valence")

if handler_type not in _CONVERTERS:
raise NotImplementedError(f"{handler_type} handlers is not yet supported.")

converter = _CONVERTERS[handler_type]
converter_spec = inspect.signature(converter)

converter_spec = inspect.signature(converter.fn)
converter_kwargs = {}

if "topologies" in converter_spec.parameters:
Expand All @@ -306,7 +325,44 @@ def convert_handlers(
assert v_site_maps is not None, "v-site maps must be provided"
converter_kwargs["v_site_maps"] = v_site_maps

return converter(handlers, **converter_kwargs)
potentials_by_type = (
{}
if potentials is None
else {potential.type: (potential, maps) for potential, maps in potentials}
)

dependencies = {}
depends_on = converter.depends_on if converter.depends_on is not None else []

if len(depends_on) > 0:
missing_deps = {dep for dep in depends_on if dep not in potentials_by_type}
assert len(missing_deps) == 0, "missing dependencies"

dependencies = {dep: potentials_by_type[dep] for dep in depends_on}
assert "dependencies" in converter_spec.parameters, "dependencies not accepted"

if "dependencies" in converter_spec.parameters:
converter_kwargs["dependencies"] = dependencies

converted = converter.fn(handlers, **converter_kwargs)
converted = [converted] if not isinstance(converted, list) else converted

converted_by_type = {
potential.type: (potential, maps) for potential, maps in converted
}
assert len(converted_by_type) == len(converted), "duplicate potentials found"

potentials_by_type = {
**{
potential.type: (potential, maps)
for potential, maps in potentials_by_type.values()
if potential.type not in depends_on
and potential.type not in converted_by_type
},
**converted_by_type,
}

return [*potentials_by_type.values()]


def _convert_topology(
Expand Down Expand Up @@ -351,6 +407,26 @@ def _convert_topology(
)


def _resolve_conversion_order(handler_types: list[str]) -> list[str]:
    """Resolve the order in which the handlers should be converted, based on their
    dependencies with each other.

    Args:
        handler_types: The types of the handlers that need to be converted.

    Returns:
        The same handler types, sorted such that each one appears after all of the
        handler types that its converter depends on.

    Raises:
        networkx.NetworkXUnfeasible: If the dependencies between the converters
            form a cycle.
    """
    dep_graph = networkx.DiGraph()
    dep_graph.add_nodes_from(handler_types)

    for handler_type in handler_types:
        # handler types without a registered converter are treated as having no
        # dependencies rather than raising a bare ``KeyError`` here, so that the
        # caller can either skip them or report them as unsupported.
        converter = _CONVERTERS.get(handler_type)

        if converter is None or converter.depends_on is None:
            continue

        for dep in converter.depends_on:
            dep_graph.add_edge(dep, handler_type)

    # ``add_edge`` implicitly adds any dependency that was not requested as a new
    # node. Only return the types that were actually asked for, so that callers
    # never receive a handler type they have no handlers for.
    requested = set(handler_types)

    return [
        handler_type
        for handler_type in networkx.topological_sort(dep_graph)
        if handler_type in requested
    ]


def convert_interchange(
interchange: openff.interchange.Interchange | list[openff.interchange.Interchange],
) -> tuple[smee.TensorForceField, list[smee.TensorTopology]]:
Expand Down Expand Up @@ -380,6 +456,9 @@ def convert_interchange(
>>>
>>> tensor_ff, tensor_topologies = convert_interchange(interchanges)
"""
importlib.import_module("smee.converters.openff.nonbonded")
importlib.import_module("smee.converters.openff.valence")

interchanges = (
[interchange]
if isinstance(interchange, openff.interchange.Interchange)
Expand Down Expand Up @@ -417,19 +496,29 @@ def convert_interchange(
if "Constraints" in handlers_by_type:
constraints = _convert_constraints(handlers_by_type.pop("Constraints"))

potentials, parameter_maps_by_handler = [], {}
conversion_order = _resolve_conversion_order([*handlers_by_type])
converted = []

for handler_type in conversion_order:
handlers = handlers_by_type[handler_type]

for handler_type, handlers in handlers_by_type.items():
if (
sum(len(handler.potentials) for handler in handlers if handler is not None)
== 0
):
continue

potential, parameter_map = convert_handlers(handlers, topologies, v_site_maps)
potentials.append(potential)
converted = convert_handlers(handlers, topologies, v_site_maps, converted)

parameter_maps_by_handler[potential.type] = parameter_map
# handlers may either return multiple potentials, or condense multiple already
# converted potentials into a single one (e.g. electrostatics into some polarizable
# potential)
potentials = []
parameter_maps_by_handler = {}

for potential, parameter_maps in converted:
potentials.append(potential)
parameter_maps_by_handler[potential.type] = parameter_maps

tensor_topologies = [
_convert_topology(
Expand Down
46 changes: 35 additions & 11 deletions smee/tests/convertors/openff/test_openff.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import importlib

import openff.interchange.models
import openff.toolkit
import openff.units
import pytest
import torch

import smee
import smee.tests.utils
from smee.converters.openff._openff import (
_CONVERTERS,
_DEFAULT_UNITS,
_convert_topology,
_Converter,
_resolve_conversion_order,
convert_handlers,
convert_interchange,
smirnoff_parameter_converter,
Expand All @@ -20,24 +24,27 @@ def test_parameter_converter():
lambda x: None
)
assert "Dummy" in _CONVERTERS
assert "parm-a" in _DEFAULT_UNITS["Dummy"]
assert "parm-a" in _CONVERTERS["Dummy"].units

with pytest.raises(KeyError, match="A Dummy converter is already"):
smirnoff_parameter_converter("Dummy", {})(lambda x: None)

del _CONVERTERS["Dummy"]
del _DEFAULT_UNITS["Dummy"]


def test_convert_handler(ethanol, ethanol_interchange, mocker):
mock_result = mocker.MagicMock()
# avoid already registered converter error
importlib.import_module("smee.converters.openff.nonbonded")

mock_deps = [(mocker.MagicMock(type="mock"), [mocker.MagicMock()])]
mock_result = (mocker.MagicMock(), [])

mock_vectorize = mocker.patch(
"smee.converters.openff.nonbonded.convert_vdw",
mock_convert = mocker.patch(
"smee.tests.utils.mock_convert_fn_with_deps",
autospec=True,
return_value=mock_result,
)
mocker.patch.dict(_CONVERTERS, {"vdW": mock_vectorize})
mocker.patch.dict(_CONVERTERS, {"vdW": _Converter(mock_convert, {}, ["mock"])})

handlers = [ethanol_interchange.collections["vdW"]]
topologies = [ethanol.to_topology()]
Expand All @@ -52,12 +59,15 @@ def test_convert_handler(ethanol, ethanol_interchange, mocker):
smee.VSiteMap([v_site], {v_site: ethanol.n_atoms}, torch.tensor([[0]]))
]

result = convert_handlers(handlers, topologies, v_site_maps)
result = convert_handlers(handlers, topologies, v_site_maps, mock_deps)

mock_vectorize.assert_called_once_with(
handlers, topologies=topologies, v_site_maps=v_site_maps
mock_convert.assert_called_once_with(
handlers,
topologies=topologies,
v_site_maps=v_site_maps,
dependencies={"mock": mock_deps[0]},
)
assert result == mock_result
assert result == [mock_result]


def test_convert_topology(formaldehyde, mocker):
Expand Down Expand Up @@ -91,6 +101,20 @@ def test_convert_topology(formaldehyde, mocker):
assert topology.constraints == constraints


def test_resolve_conversion_order(mocker):
    """Converters should be ordered so that each runs after its dependencies."""
    # dependency chain: "a" needs "c", which in turn needs "b"
    depends_on_by_type = {"a": ["c"], "b": [], "c": ["b"]}

    mock_converters = {
        handler_type: _Converter(mocker.MagicMock(), {}, depends_on)
        for handler_type, depends_on in depends_on_by_type.items()
    }
    mocker.patch.dict(_CONVERTERS, mock_converters)

    expected_order = ["b", "c", "a"]
    assert _resolve_conversion_order(["a", "b", "c"]) == expected_order


def test_convert_interchange():
force_field = openff.toolkit.ForceField()
force_field.get_parameter_handler("Electrostatics")
Expand Down
11 changes: 11 additions & 0 deletions smee/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
LJParam = typing.NamedTuple("LJParam", [("eps", float), ("sig", float)])


def mock_convert_fn_with_deps(
    handlers: list[openff.interchange.smirnoff.SMIRNOFFvdWCollection],
    topologies: list[openff.toolkit.Topology],
    v_site_maps: list[smee.VSiteMap | None],
    dependencies: dict[
        str, tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]
    ],
) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]:
    """A placeholder converter function that accepts a ``dependencies`` argument.

    This function is never meant to be called directly and always raises. It only
    provides a signature that tests can patch over.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()


def convert_lj_to_dexp(potential: smee.TensorPotential):
potential.fn = smee.EnergyFn.VDW_DEXP

Expand Down

0 comments on commit 88347f5

Please sign in to comment.