-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Running RxnMapper in batches and without raising errors (#36)
- Loading branch information
Showing
9 changed files
with
251 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import logging | ||
from typing import Any, Dict, Iterable, Iterator, List | ||
|
||
from rxn.utilities.containers import chunker | ||
|
||
from .core import RXNMapper | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.addHandler(logging.NullHandler()) | ||
|
||
# Alias for what the original mapper returns | ||
ResultWithInfo = Dict[str, Any] | ||
|
||
|
||
class BatchedMapper: | ||
""" | ||
Class to atom-map reactions in batches, with error control. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
batch_size: int, | ||
canonicalize: bool = False, | ||
placeholder_for_invalid: str = ">>", | ||
): | ||
self.mapper = RXNMapper() | ||
self.batch_size = batch_size | ||
self.canonicalize = canonicalize | ||
self.placeholder_for_invalid = placeholder_for_invalid | ||
|
||
def map_reactions(self, reaction_smiles: Iterable[str]) -> Iterator[str]: | ||
"""Map the given reactions, returning the mapped SMILES strings. | ||
Args: | ||
reaction_smiles: reaction SMILES strings to map. | ||
Returns: | ||
iterator over mapped strings; a placeholder is returned for the | ||
entries that failed. | ||
""" | ||
for result in self.map_reactions_with_info(reaction_smiles): | ||
if result == {}: | ||
yield self.placeholder_for_invalid | ||
else: | ||
yield result["mapped_rxn"] | ||
|
||
def map_reactions_with_info( | ||
self, reaction_smiles: Iterable[str], detailed: bool = False | ||
) -> Iterator[ResultWithInfo]: | ||
"""Map the given reactions, returning the results as dictionaries. | ||
Args: | ||
reaction_smiles: reaction SMILES strings to map. | ||
detailed: detailed output or not. | ||
Returns: | ||
iterator over dictionaries (in the format returned by the RXNMapper class); | ||
an empty dictionary is returned for the entries that failed. | ||
""" | ||
for rxns_chunk in chunker(reaction_smiles, chunk_size=self.batch_size): | ||
yield from self._map_reaction_batch(rxns_chunk, detailed=detailed) | ||
|
||
def _map_reaction_batch( | ||
self, reaction_batch: List[str], detailed: bool | ||
) -> Iterator[ResultWithInfo]: | ||
try: | ||
yield from self._try_map_reaction_batch(reaction_batch, detailed=detailed) | ||
except Exception: | ||
logger.warning( | ||
f"Error while mapping chunk of {len(reaction_batch)} reactions. " | ||
"Mapping them individually." | ||
) | ||
yield from self._map_reactions_one_by_one(reaction_batch, detailed=detailed) | ||
|
||
def _try_map_reaction_batch( | ||
self, reaction_batch: List[str], detailed: bool | ||
) -> List[ResultWithInfo]: | ||
""" | ||
Map a reaction batch, without error handling. | ||
Note: we return a list, not a generator function, to avoid returning partial | ||
results. | ||
""" | ||
return self.mapper.get_attention_guided_atom_maps( | ||
reaction_batch, | ||
canonicalize_rxns=self.canonicalize, | ||
detailed_output=detailed, | ||
) | ||
|
||
def _map_reactions_one_by_one( | ||
self, reaction_batch: Iterable[str], detailed: bool | ||
) -> Iterator[ResultWithInfo]: | ||
""" | ||
Map a reaction batch, one reaction at a time. | ||
Reactions causing an error will be replaced by a placeholder. | ||
""" | ||
for reaction in reaction_batch: | ||
try: | ||
yield self._try_map_reaction_batch([reaction], detailed=detailed)[0] | ||
except Exception as e: | ||
logger.info( | ||
f"Reaction causing the error: {reaction}; {e.__class__.__name__}: {e}" | ||
) | ||
yield {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import itertools | ||
|
||
import pytest | ||
|
||
from rxnmapper.batched_mapper import BatchedMapper | ||
|
||
from .utils import assert_correct_maps | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def batched_mapper() -> BatchedMapper: | ||
""" | ||
Fixture to get the RXNMapper, cached with module scope so that the weights | ||
do not need to be loaded multiple times. | ||
""" | ||
return BatchedMapper(batch_size=4, canonicalize=False) | ||
|
||
|
||
def test_normal_behavior(batched_mapper: BatchedMapper) -> None: | ||
# Simple example with 5 reactions, given in different RXN formats | ||
rxns = [ | ||
"CC[O-]~[Na+].BrCC>>CCOCC", | ||
"CCC[O-]~[Na+].BrCC>>CCOCCC", | ||
"CC[O-].[Na+].BrCC>>CCOCC |f:0.1|", | ||
"NCC[O-]~[Na+].BrCC>>NCCOCC", | ||
"C(C)C[O-]~[Na+].BrCC>>C(C)COCC", | ||
] | ||
|
||
results = batched_mapper.map_reactions(rxns) | ||
|
||
assert list(results) == [ | ||
"[CH3:5][CH2:4][O-:3]~[Na+].Br[CH2:2][CH3:1]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]", | ||
"[CH3:6][CH2:5][CH2:4][O-:3]~[Na+].Br[CH2:1][CH3:2]>>[CH3:1][CH2:2][O:3][CH2:4][CH2:5][CH3:6]", | ||
"Br[CH2:2][CH3:1].[CH3:5][CH2:4][O-:3].[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5] |f:1.2|", | ||
"[NH2:1][CH2:2][CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[NH2:1][CH2:2][CH2:3][O:4][CH2:5][CH3:6]", | ||
"[CH2:1]([CH3:2])[CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[CH2:1]([CH3:2])[CH2:3][O:4][CH2:5][CH3:6]", | ||
] | ||
|
||
|
||
def test_map_with_info(batched_mapper: BatchedMapper) -> None: | ||
# Simple example with 5 reactions, given in different RXN formats | ||
rxns = [ | ||
"CC[O-]~[Na+].BrCC>>CCOCC", | ||
"CCC[O-]~[Na+].BrCC>>CCOCCC", | ||
"CC[O-].[Na+].BrCC>>CCOCC |f:0.1|", | ||
"NCC[O-]~[Na+].BrCC>>NCCOCC", | ||
"C(C)C[O-]~[Na+].BrCC>>C(C)COCC", | ||
] | ||
|
||
for detailed in [True, False]: | ||
results = batched_mapper.map_reactions_with_info(rxns, detailed=detailed) | ||
results_with_original_mapper = ( | ||
batched_mapper.mapper.get_attention_guided_atom_maps( | ||
rxns, canonicalize_rxns=False, detailed_output=detailed | ||
) | ||
) | ||
assert_correct_maps(results, list(results_with_original_mapper)) | ||
|
||
|
||
def test_error(batched_mapper: BatchedMapper) -> None: | ||
# When there is an error, the placeholder is returned instead | ||
|
||
too_long = ".".join(itertools.repeat("ClCCl", 200)) + ".CCC[O-]~[Na+].BrCC>>CCOCCC" | ||
invalid_symbol = "AAgCC[O-]~[Na+].BrCC>>C(C)COCC" | ||
rxns = [ | ||
"CC[O-]~[Na+].BrCC>>CCOCC", | ||
too_long, | ||
"CC[O-].[Na+].BrCC>>CCOCC |f:0.1|", | ||
"NCC[O-]~[Na+].BrCC>>NCCOCC", | ||
invalid_symbol, | ||
] | ||
|
||
results = batched_mapper.map_reactions(rxns) | ||
|
||
assert list(results) == [ | ||
"[CH3:5][CH2:4][O-:3]~[Na+].Br[CH2:2][CH3:1]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]", | ||
">>", | ||
"Br[CH2:2][CH3:1].[CH3:5][CH2:4][O-:3].[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5] " | ||
"|f:1.2|", | ||
"[NH2:1][CH2:2][CH2:3][O-:4]~[Na+].Br[CH2:5][CH3:6]>>[NH2:1][CH2:2][CH2:3][O:4][CH2:5][CH3:6]", | ||
">>", | ||
] | ||
|
||
# as comparison: RxnMapper would fail: | ||
with pytest.raises(Exception): | ||
batched_mapper.mapper.get_attention_guided_atom_maps([too_long]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Any, Dict, Iterable | ||
|
||
import numpy as np | ||
|
||
|
||
def assert_correct_map(value_1: Dict[str, Any], value_2: Dict[str, Any]) -> None: | ||
mandatory_keys = ["mapped_rxn", "confidence"] | ||
|
||
# Exact matches | ||
for key in ["mapped_rxn", "pxr_mapping_vector", "tokens"]: | ||
if key not in mandatory_keys and key not in value_1: | ||
continue | ||
assert value_1[key] == value_2[key] | ||
|
||
# close match on single number | ||
for key in ["confidence"]: | ||
if key not in mandatory_keys and key not in value_1: | ||
continue | ||
assert np.isclose(value_1[key], value_2[key]) | ||
|
||
# close match on multiple values | ||
for key in [ | ||
"pxr_confidences", | ||
"pxrrxp_attns", | ||
"tokensxtokens_attns", | ||
"mapping_tuples", | ||
]: | ||
if key not in mandatory_keys and key not in value_1: | ||
continue | ||
assert np.allclose(value_1[key], value_2[key], rtol=1e-4, atol=1e-7) | ||
|
||
|
||
def assert_correct_maps( | ||
values_1: Iterable[Dict[str, Any]], values_2: Iterable[Dict[str, Any]] | ||
) -> None: | ||
for value_1, value_2 in zip(values_1, values_2): | ||
assert_correct_map(value_1, value_2) |