Skip to content

Commit

Permalink
Running RxnMapper in batches and without raising errors (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher authored Feb 7, 2023
1 parent ca5f679 commit df354df
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 20 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ The results contain the mapped reactions and confidence scores:
'confidence': 0.9704424331552834}]
```

To account for batching and error handling automatically, you can use `BatchedMapper` instead:
```python
from rxnmapper import BatchedMapper
rxn_mapper = BatchedMapper(batch_size=32)
rxns = ['CC[O-]~[Na+].BrCC>>CCOCC', 'invalid>>reaction']

# The following calls work with input of arbitrary size. They never raise: for the
# invalid second reaction, the first call returns ">>" and the second an empty dictionary.
results = list(rxn_mapper.map_reactions(rxns)) # results as strings directly
results = list(rxn_mapper.map_reactions_with_info(rxns)) # results as dictionaries (as above)
```

### Testing

You can run the examples above with the test suite as well:
Expand Down
2 changes: 2 additions & 0 deletions rxnmapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
__version__ = "0.2.4" # managed by bump2version


from .batched_mapper import BatchedMapper
from .core import RXNMapper

# Names exported by `from rxnmapper import *`.
__all__ = [
    "BatchedMapper",
    "RXNMapper",
]
105 changes: 105 additions & 0 deletions rxnmapper/batched_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import logging
from typing import Any, Dict, Iterable, Iterator, List

from rxn.utilities.containers import chunker

from .core import RXNMapper

# Module-level logger; the NullHandler is a no-op handler, so this module does
# not configure any log output by itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Alias for what the original mapper returns
ResultWithInfo = Dict[str, Any]


class BatchedMapper:
    """
    Class to atom-map reactions in batches, with error control.

    Reactions are handed to the underlying RXNMapper ``batch_size`` at a time.
    When mapping a whole batch raises, the reactions of that batch are mapped
    again one at a time, and only the ones that still fail are replaced by an
    empty result.
    """

    def __init__(
        self,
        batch_size: int,
        canonicalize: bool = False,
        placeholder_for_invalid: str = ">>",
    ):
        """
        Args:
            batch_size: number of reactions passed to the mapper per call.
            canonicalize: forwarded to the mapper as ``canonicalize_rxns``.
            placeholder_for_invalid: string yielded by ``map_reactions`` for
                the reactions that could not be mapped.
        """
        self.mapper = RXNMapper()
        self.batch_size = batch_size
        self.canonicalize = canonicalize
        self.placeholder_for_invalid = placeholder_for_invalid

    def map_reactions(self, reaction_smiles: Iterable[str]) -> Iterator[str]:
        """Map the given reactions, returning the mapped SMILES strings.

        Args:
            reaction_smiles: reaction SMILES strings to map.

        Returns:
            iterator over mapped strings; a placeholder is returned for the
            entries that failed.
        """
        for result in self.map_reactions_with_info(reaction_smiles):
            # An empty dictionary is the marker for a reaction that failed.
            if not result:
                yield self.placeholder_for_invalid
            else:
                yield result["mapped_rxn"]

    def map_reactions_with_info(
        self, reaction_smiles: Iterable[str], detailed: bool = False
    ) -> Iterator[ResultWithInfo]:
        """Map the given reactions, returning the results as dictionaries.

        Args:
            reaction_smiles: reaction SMILES strings to map.
            detailed: detailed output or not.

        Returns:
            iterator over dictionaries (in the format returned by the RXNMapper class);
            an empty dictionary is returned for the entries that failed.
        """
        for rxns_chunk in chunker(reaction_smiles, chunk_size=self.batch_size):
            yield from self._map_reaction_batch(rxns_chunk, detailed=detailed)

    def _map_reaction_batch(
        self, reaction_batch: List[str], detailed: bool
    ) -> Iterator[ResultWithInfo]:
        """
        Map one batch, falling back to one-by-one mapping if the batch fails.

        Args:
            reaction_batch: reactions of the current batch.
            detailed: detailed output or not.

        Returns:
            iterator over one result per reaction of the batch.
        """
        try:
            batch_results = self._try_map_reaction_batch(
                reaction_batch, detailed=detailed
            )
        except Exception:
            logger.warning(
                "Error while mapping chunk of %s reactions. "
                "Mapping them individually.",
                len(reaction_batch),
            )
        else:
            # Yield outside of the ``try`` body: an exception thrown into the
            # suspended generator by the consumer must not be mistaken for a
            # mapping failure, which would map and yield the batch a second time.
            yield from batch_results
            return

        yield from self._map_reactions_one_by_one(reaction_batch, detailed=detailed)

    def _try_map_reaction_batch(
        self, reaction_batch: List[str], detailed: bool
    ) -> List[ResultWithInfo]:
        """
        Map a reaction batch, without error handling.

        Note: we return a list, not a generator function, to avoid returning partial
        results.
        """
        return self.mapper.get_attention_guided_atom_maps(
            reaction_batch,
            canonicalize_rxns=self.canonicalize,
            detailed_output=detailed,
        )

    def _map_reactions_one_by_one(
        self, reaction_batch: Iterable[str], detailed: bool
    ) -> Iterator[ResultWithInfo]:
        """
        Map a reaction batch, one reaction at a time.

        Reactions causing an error will be replaced by a placeholder (an empty
        dictionary).
        """
        for reaction in reaction_batch:
            try:
                result = self._try_map_reaction_batch([reaction], detailed=detailed)[0]
            except Exception as e:
                logger.info(
                    "Reaction causing the error: %s; %s: %s",
                    reaction,
                    e.__class__.__name__,
                    e,
                )
                result = {}
            yield result
2 changes: 1 addition & 1 deletion rxnmapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def get_attention_guided_atom_maps(
absolute_product_inds: bool = False,
force_layer: Optional[int] = None,
force_head: Optional[int] = None,
):
) -> List[Dict[str, Any]]:
"""Generate atom-mapping for reactions.
Args:
Expand Down
1 change: 0 additions & 1 deletion rxnmapper/smiles_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,6 @@ def process_reaction_with_product_maps_atoms(rxn, skip_if_not_in_precursors=Fals
warnings = []

for p_map in product_atom_maps:

if skip_if_not_in_precursors and p_map not in precursors_atom_maps:
products_maps.append(-1)
elif int(p_map) == 0:
Expand Down
Empty file added tests/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions tests/test_batched_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import itertools

import pytest

from rxnmapper.batched_mapper import BatchedMapper

from .utils import assert_correct_maps


@pytest.fixture(scope="module")
def batched_mapper() -> BatchedMapper:
    """
    Fixture providing a BatchedMapper (batch size 4, no canonicalization).

    It is cached with module scope so that the weights do not need to be
    loaded multiple times.
    """
    mapper = BatchedMapper(batch_size=4, canonicalize=False)
    return mapper


def test_normal_behavior(batched_mapper: BatchedMapper) -> None:
    """Valid reactions are all mapped, in input order.

    Five reactions are given (more than the fixture's batch size of 4), in
    different RXN formats.
    """
    rxns = [
        "CC[O-]~[Na+].BrCC>>CCOCC",
        "CCC[O-]~[Na+].BrCC>>CCOCCC",
        "CC[O-].[Na+].BrCC>>CCOCC |f:0.1|",
        "NCC[O-]~[Na+].BrCC>>NCCOCC",
        "C(C)C[O-]~[Na+].BrCC>>C(C)COCC",
    ]
    expected = [
        "[CH3:5][CH2:4][O-:3]~[Na+].Br[CH2:2][CH3:1]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]",
        "[CH3:6][CH2:5][CH2:4][O-:3]~[Na+].Br[CH2:1][CH3:2]>>[CH3:1][CH2:2][O:3][CH2:4][CH2:5][CH3:6]",
        "Br[CH2:2][CH3:1].[CH3:5][CH2:4][O-:3].[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5] |f:1.2|",
        "[NH2:1][CH2:2][CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[NH2:1][CH2:2][CH2:3][O:4][CH2:5][CH3:6]",
        "[CH2:1]([CH3:2])[CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[CH2:1]([CH3:2])[CH2:3][O:4][CH2:5][CH3:6]",
    ]

    mapped = list(batched_mapper.map_reactions(rxns))

    assert mapped == expected


def test_map_with_info(batched_mapper: BatchedMapper) -> None:
    """Batched results with info agree with the wrapped mapper's own output.

    Five reactions are given, in different RXN formats; both the detailed and
    the non-detailed outputs are compared.
    """
    rxns = [
        "CC[O-]~[Na+].BrCC>>CCOCC",
        "CCC[O-]~[Na+].BrCC>>CCOCCC",
        "CC[O-].[Na+].BrCC>>CCOCC |f:0.1|",
        "NCC[O-]~[Na+].BrCC>>NCCOCC",
        "C(C)C[O-]~[Na+].BrCC>>C(C)COCC",
    ]
    reference_mapper = batched_mapper.mapper

    for detailed in (True, False):
        batched = batched_mapper.map_reactions_with_info(rxns, detailed=detailed)
        # Reference: the same reactions mapped in one call by the wrapped mapper.
        reference = reference_mapper.get_attention_guided_atom_maps(
            rxns, canonicalize_rxns=False, detailed_output=detailed
        )
        assert_correct_maps(batched, list(reference))


def test_error(batched_mapper: BatchedMapper) -> None:
    """Failing reactions yield the placeholder; the others are still mapped."""
    # Two problematic inputs: a reaction with 200 extra "ClCCl" fragments, and
    # one starting with the invalid symbol "AAg".
    many_fragments = ".".join(itertools.repeat("ClCCl", 200))
    too_long = many_fragments + ".CCC[O-]~[Na+].BrCC>>CCOCCC"
    invalid_symbol = "AAgCC[O-]~[Na+].BrCC>>C(C)COCC"

    rxns = [
        "CC[O-]~[Na+].BrCC>>CCOCC",
        too_long,
        "CC[O-].[Na+].BrCC>>CCOCC |f:0.1|",
        "NCC[O-]~[Na+].BrCC>>NCCOCC",
        invalid_symbol,
    ]
    expected = [
        "[CH3:5][CH2:4][O-:3]~[Na+].Br[CH2:2][CH3:1]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]",
        ">>",
        "Br[CH2:2][CH3:1].[CH3:5][CH2:4][O-:3].[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5] "
        "|f:1.2|",
        "[NH2:1][CH2:2][CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[NH2:1][CH2:2][CH2:3][O:4][CH2:5][CH3:6]",
        ">>",
    ]

    mapped = list(batched_mapper.map_reactions(rxns))
    assert mapped == expected

    # For comparison, the wrapped RXNMapper raises on the same input.
    with pytest.raises(Exception):
        batched_mapper.mapper.get_attention_guided_atom_maps([too_long])
26 changes: 8 additions & 18 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest

from rxnmapper import RXNMapper

from .utils import assert_correct_maps


@pytest.fixture(scope="module")
def rxn_mapper() -> RXNMapper:
Expand All @@ -13,11 +14,6 @@ def rxn_mapper() -> RXNMapper:
return RXNMapper()


def is_correct_map(result, exp):
    """Assert that ``result`` matches ``exp``.

    The mapped reaction must be equal, and the confidence must be close
    (as defined by ``np.isclose``).
    """
    same_mapping = result["mapped_rxn"] == exp["mapped_rxn"]
    assert same_mapping
    close_confidence = np.isclose(result["confidence"], exp["confidence"])
    assert close_confidence


def test_example_maps(rxn_mapper: RXNMapper):
rxns = [
"CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F",
Expand All @@ -40,8 +36,7 @@ def test_example_maps(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)


def test_fragment_bond(rxn_mapper: RXNMapper):
Expand All @@ -54,8 +49,7 @@ def test_fragment_bond(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)


def test_extended_smiles_format(rxn_mapper: RXNMapper):
Expand All @@ -68,8 +62,7 @@ def test_extended_smiles_format(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)


def test_no_canonicalization(rxn_mapper: RXNMapper):
Expand All @@ -85,8 +78,7 @@ def test_no_canonicalization(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns, canonicalize_rxns=False)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)


def test_reaction_with_invalid_valence(rxn_mapper: RXNMapper):
Expand All @@ -101,8 +93,7 @@ def test_reaction_with_invalid_valence(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns, canonicalize_rxns=False)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)


def test_multiple_products(rxn_mapper: RXNMapper):
Expand All @@ -116,5 +107,4 @@ def test_multiple_products(rxn_mapper: RXNMapper):
]

results = rxn_mapper.get_attention_guided_atom_maps(rxns)
for res, exp in zip(results, expected):
is_correct_map(res, exp)
assert_correct_maps(results, expected)
37 changes: 37 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Dict, Iterable

import numpy as np


def assert_correct_map(value_1: Dict[str, Any], value_2: Dict[str, Any]) -> None:
    """Assert that two mapping results agree.

    "mapped_rxn" and "confidence" are always compared; the other known keys
    are compared only when present in ``value_1``.

    Args:
        value_1: result to check; decides which optional keys are compared.
        value_2: result to compare against.
    """
    required = {"mapped_rxn", "confidence"}

    def must_compare(key: str) -> bool:
        return key in required or key in value_1

    # Keys that must match exactly
    for key in ("mapped_rxn", "pxr_mapping_vector", "tokens"):
        if must_compare(key):
            assert value_1[key] == value_2[key]

    # Single number, compared with a tolerance
    for key in ("confidence",):
        if must_compare(key):
            assert np.isclose(value_1[key], value_2[key])

    # Collections of numbers, compared with a tolerance
    for key in (
        "pxr_confidences",
        "pxrrxp_attns",
        "tokensxtokens_attns",
        "mapping_tuples",
    ):
        if must_compare(key):
            assert np.allclose(value_1[key], value_2[key], rtol=1e-4, atol=1e-7)


def assert_correct_maps(
    values_1: Iterable[Dict[str, Any]], values_2: Iterable[Dict[str, Any]]
) -> None:
    """Assert that two sequences of mapping results agree, element by element.

    Args:
        values_1: results to check (any iterable, including a generator).
        values_2: results to compare against.

    Raises:
        AssertionError: if the iterables have different lengths, or if a pair
            of results does not match.
    """
    from itertools import zip_longest

    # A plain zip() stops at the shorter iterable, so missing or extra results
    # would go unnoticed; pad with a sentinel to detect a length mismatch.
    missing = object()
    pairs = zip_longest(values_1, values_2, fillvalue=missing)
    for index, (value_1, value_2) in enumerate(pairs):
        assert (
            value_1 is not missing and value_2 is not missing
        ), f"Different number of results: one iterable ended at index {index}"
        assert_correct_map(value_1, value_2)

0 comments on commit df354df

Please sign in to comment.