From 88347f5cf6ca6b663a57cd45f511c0be34c3274d Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Tue, 16 Jul 2024 20:55:33 -0400 Subject: [PATCH] Add support for advanced OpenFF handler converters (#110) --- devtools/envs/base.yaml | 2 + smee/converters/openff/_openff.py | 133 ++++++++++++++++---- smee/tests/convertors/openff/test_openff.py | 46 +++++-- smee/tests/utils.py | 11 ++ 4 files changed, 159 insertions(+), 33 deletions(-) diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 9aef575..ec7a9a1 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -17,6 +17,8 @@ dependencies: - pydantic - nnpops + - networkx + # Optional packages ### MM simulations diff --git a/smee/converters/openff/_openff.py b/smee/converters/openff/_openff.py index 6f0d7d6..a90554b 100644 --- a/smee/converters/openff/_openff.py +++ b/smee/converters/openff/_openff.py @@ -1,6 +1,8 @@ import importlib import inspect +import typing +import networkx import openff.interchange.components.potentials import openff.interchange.models import openff.interchange.smirnoff @@ -11,8 +13,17 @@ import smee import smee.geometry -_CONVERTERS = {} -_DEFAULT_UNITS = {} + +class _Converter(typing.NamedTuple): + fn: typing.Callable + """The function that will convert the parameters of a handler into tensors.""" + units: dict[str, openff.units.Unit] + """The default units of each parameter in the handler.""" + depends_on: list[str] | None + """The names of other converters that this converter should be run after.""" + + +_CONVERTERS: dict[str, _Converter] = {} _ANGSTROM = openff.units.unit.angstrom _RADIANS = openff.units.unit.radians @@ -25,7 +36,9 @@ def smirnoff_parameter_converter( - type_: str, default_units: dict[str, openff.units.Unit] + type_: str, + default_units: dict[str, openff.units.Unit], + depends_on: list[str] | None = None, ): """A decorator used to flag a function as being able to convert a parameter handlers parameters into tensors. @@ -33,14 +46,17 @@ def smirnoff_parameter_converter( Args: type_: The type of parameter handler that the decorated function can convert. default_units: The default units of each parameter in the handler. + depends_on: The names of other handlers that this handler depends on. When set, + the convert function should additionally take in a list of the already + converted potentials and return a new list of potentials that should either + include or replace the original potentials. """ def parameter_converter_inner(func): if type_ in _CONVERTERS: raise KeyError(f"A {type_} converter is already registered.") - _CONVERTERS[type_] = func - _DEFAULT_UNITS[type_] = default_units + _CONVERTERS[type_] = _Converter(func, default_units, depends_on) return func @@ -82,7 +98,7 @@ def _handlers_to_potential( _get_value( parameters_by_key[parameter_key], column, - _DEFAULT_UNITS[handler_type], + _CONVERTERS[handler_type].units, ) for column in parameter_cols ] @@ -100,7 +116,7 @@ def _handlers_to_potential( for handler in handlers } attributes_by_column = { - k: v.m_as(_DEFAULT_UNITS[handler_type][k]) + k: v.m_as(_CONVERTERS[handler_type].units[k]) for k, v in attributes_by_column.items() } attributes = torch.tensor( @@ -115,7 +131,7 @@ def _handlers_to_potential( parameter_keys=parameter_keys, parameter_cols=parameter_cols, parameter_units=tuple( - _DEFAULT_UNITS[handler_type][column] for column in parameter_cols + _CONVERTERS[handler_type].units[column] for column in parameter_cols ), attributes=attributes, attribute_cols=attribute_cols, @@ -123,7 +139,7 @@ def _handlers_to_potential( None if attribute_cols is None else tuple( - _DEFAULT_UNITS[handler_type][column] for column in attribute_cols + _CONVERTERS[handler_type].units[column] for column in attribute_cols ) ), ) @@ -250,7 +266,10 @@ def convert_handlers( handlers: list[openff.interchange.smirnoff.SMIRNOFFCollection], topologies: list[openff.toolkit.Topology], v_site_maps: list[smee.VSiteMap | None] | None = None, -) -> tuple[smee.TensorPotential, list[smee.ParameterMap]]: + potentials: ( + list[tuple[smee.TensorPotential, list[smee.ParameterMap]]] | None + ) = None, +) -> list[tuple[smee.TensorPotential, list[smee.ParameterMap]]]: """Convert a set of SMIRNOFF parameter handlers into a set of tensor potentials. Args: @@ -258,13 +277,14 @@ def convert_handlers( objects to convert. topologies: The topologies associated with each interchange object. v_site_maps: The v-site maps associated with each interchange object. + potentials: Already converted parameter handlers that may be required as + dependencies. Returns: The potential containing the values of the parameters in each handler collection, and a list of maps (one per topology) between molecule elements (e.g. bond indices) and parameter indices. - Examples: >>> from openff.toolkit import ForceField, Molecule @@ -283,21 +303,20 @@ def convert_handlers( >>> >>> vdw_potential, applied_vdw_parameters = convert_handlers(interchanges) """ + importlib.import_module("smee.converters.openff.nonbonded") + importlib.import_module("smee.converters.openff.valence") + handler_types = {handler.type for handler in handlers} assert len(handler_types) == 1, "multiple handler types found" handler_type = next(iter(handler_types)) assert len(handlers) == len(topologies), "mismatched number of topologies" - importlib.import_module("smee.converters.openff.nonbonded") - importlib.import_module("smee.converters.openff.valence") - if handler_type not in _CONVERTERS: raise NotImplementedError(f"{handler_type} handlers is not yet supported.") converter = _CONVERTERS[handler_type] - converter_spec = inspect.signature(converter) - + converter_spec = inspect.signature(converter.fn) converter_kwargs = {} if "topologies" in converter_spec.parameters: @@ -306,7 +325,44 @@ def convert_handlers( assert v_site_maps is not None, "v-site maps must be provided" converter_kwargs["v_site_maps"] = v_site_maps - return converter(handlers, **converter_kwargs) + potentials_by_type = ( + {} + if potentials is None + else {potential.type: (potential, maps) for potential, maps in potentials} + ) + + dependencies = {} + depends_on = converter.depends_on if converter.depends_on is not None else [] + + if len(depends_on) > 0: + missing_deps = {dep for dep in depends_on if dep not in potentials_by_type} + assert len(missing_deps) == 0, "missing dependencies" + + dependencies = {dep: potentials_by_type[dep] for dep in depends_on} + assert "dependencies" in converter_spec.parameters, "dependencies not accepted" + + if "dependencies" in converter_spec.parameters: + converter_kwargs["dependencies"] = dependencies + + converted = converter.fn(handlers, **converter_kwargs) + converted = [converted] if not isinstance(converted, list) else converted + + converted_by_type = { + potential.type: (potential, maps) for potential, maps in converted + } + assert len(converted_by_type) == len(converted), "duplicate potentials found" + + potentials_by_type = { + **{ + potential.type: (potential, maps) + for potential, maps in potentials_by_type.values() + if potential.type not in depends_on + and potential.type not in converted_by_type + }, + **converted_by_type, + } + + return [*potentials_by_type.values()] def _convert_topology( @@ -351,6 +407,26 @@ def _convert_topology( ) +def _resolve_conversion_order(handler_types: list[str]) -> list[str]: + """Resolve the order in which the handlers should be converted, based on their + dependencies with each other.""" + dep_graph = networkx.DiGraph() + + for handler_type in handler_types: + dep_graph.add_node(handler_type) + + for handler_type in handler_types: + converter = _CONVERTERS[handler_type] + + if converter.depends_on is None: + continue + + for dep in converter.depends_on: + dep_graph.add_edge(dep, handler_type) + + return list(networkx.topological_sort(dep_graph)) + + def convert_interchange( interchange: openff.interchange.Interchange | list[openff.interchange.Interchange], ) -> tuple[smee.TensorForceField, list[smee.TensorTopology]]: @@ -380,6 +456,9 @@ def convert_interchange( >>> >>> tensor_ff, tensor_topologies = convert_interchange(interchanges) """ + importlib.import_module("smee.converters.openff.nonbonded") + importlib.import_module("smee.converters.openff.valence") + interchanges = ( [interchange] if isinstance(interchange, openff.interchange.Interchange) @@ -417,19 +496,29 @@ def convert_interchange( if "Constraints" in handlers_by_type: constraints = _convert_constraints(handlers_by_type.pop("Constraints")) - potentials, parameter_maps_by_handler = [], {} + conversion_order = _resolve_conversion_order([*handlers_by_type]) + converted = [] + + for handler_type in conversion_order: + handlers = handlers_by_type[handler_type] - for handler_type, handlers in handlers_by_type.items(): if ( sum(len(handler.potentials) for handler in handlers if handler is not None) == 0 ): continue - potential, parameter_map = convert_handlers(handlers, topologies, v_site_maps) - potentials.append(potential) + converted = convert_handlers(handlers, topologies, v_site_maps, converted) - parameter_maps_by_handler[potential.type] = parameter_map + # handlers may either return multiple potentials, or condense multiple already + # converted potentials into a single one (e.g. electrostatics into some polarizable + # potential) + potentials = [] + parameter_maps_by_handler = {} + + for potential, parameter_maps in converted: + potentials.append(potential) + parameter_maps_by_handler[potential.type] = parameter_maps tensor_topologies = [ _convert_topology( diff --git a/smee/tests/convertors/openff/test_openff.py b/smee/tests/convertors/openff/test_openff.py index 1db39d8..dfa79fc 100644 --- a/smee/tests/convertors/openff/test_openff.py +++ b/smee/tests/convertors/openff/test_openff.py @@ -1,3 +1,5 @@ +import importlib + import openff.interchange.models import openff.toolkit import openff.units @@ -5,10 +7,12 @@ import torch import smee +import smee.tests.utils from smee.converters.openff._openff import ( _CONVERTERS, - _DEFAULT_UNITS, _convert_topology, + _Converter, + _resolve_conversion_order, convert_handlers, convert_interchange, smirnoff_parameter_converter, @@ -20,24 +24,27 @@ def test_parameter_converter(): lambda x: None ) assert "Dummy" in _CONVERTERS - assert "parm-a" in _DEFAULT_UNITS["Dummy"] + assert "parm-a" in _CONVERTERS["Dummy"].units with pytest.raises(KeyError, match="A Dummy converter is already"): smirnoff_parameter_converter("Dummy", {})(lambda x: None) del _CONVERTERS["Dummy"] - del _DEFAULT_UNITS["Dummy"] def test_convert_handler(ethanol, ethanol_interchange, mocker): - mock_result = mocker.MagicMock() + # avoid already registered converter error + importlib.import_module("smee.converters.openff.nonbonded") + + mock_deps = [(mocker.MagicMock(type="mock"), [mocker.MagicMock()])] + mock_result = (mocker.MagicMock(), []) - mock_vectorize = mocker.patch( - "smee.converters.openff.nonbonded.convert_vdw", + mock_convert = mocker.patch( + "smee.tests.utils.mock_convert_fn_with_deps", autospec=True, return_value=mock_result, ) - mocker.patch.dict(_CONVERTERS, {"vdW": mock_vectorize}) + mocker.patch.dict(_CONVERTERS, {"vdW": _Converter(mock_convert, {}, ["mock"])}) handlers = [ethanol_interchange.collections["vdW"]] topologies = [ethanol.to_topology()] @@ -52,12 +59,15 @@ def test_convert_handler(ethanol, ethanol_interchange, mocker): smee.VSiteMap([v_site], {v_site: ethanol.n_atoms}, torch.tensor([[0]])) ] - result = convert_handlers(handlers, topologies, v_site_maps) + result = convert_handlers(handlers, topologies, v_site_maps, mock_deps) - mock_vectorize.assert_called_once_with( - handlers, topologies=topologies, v_site_maps=v_site_maps + mock_convert.assert_called_once_with( + handlers, + topologies=topologies, + v_site_maps=v_site_maps, + dependencies={"mock": mock_deps[0]}, ) - assert result == mock_result + assert result == [mock_result] def test_convert_topology(formaldehyde, mocker): @@ -91,6 +101,20 @@ def test_convert_topology(formaldehyde, mocker): assert topology.constraints == constraints +def test_resolve_conversion_order(mocker): + mocker.patch.dict( + _CONVERTERS, + { + "a": _Converter(mocker.MagicMock(), {}, ["c"]), + "b": _Converter(mocker.MagicMock(), {}, []), + "c": _Converter(mocker.MagicMock(), {}, ["b"]), + }, + ) + + order = _resolve_conversion_order(["a", "b", "c"]) + assert order == ["b", "c", "a"] + + def test_convert_interchange(): force_field = openff.toolkit.ForceField() force_field.get_parameter_handler("Electrostatics") diff --git a/smee/tests/utils.py b/smee/tests/utils.py index 687d7c1..a62163f 100644 --- a/smee/tests/utils.py +++ b/smee/tests/utils.py @@ -14,6 +14,17 @@ LJParam = typing.NamedTuple("LJParam", [("eps", float), ("sig", float)]) +def mock_convert_fn_with_deps( + handlers: list[openff.interchange.smirnoff.SMIRNOFFvdWCollection], + topologies: list[openff.toolkit.Topology], + v_site_maps: list[smee.VSiteMap | None], + dependencies: dict[ + str, tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]] + ], +) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]: + raise NotImplementedError() + + def convert_lj_to_dexp(potential: smee.TensorPotential): potential.fn = smee.EnergyFn.VDW_DEXP