Bug fixes and improvements (#24)
* Better handling of compounds with the help of rxn-chemutils

* additional updates

* start list with differences

* refactor function

* improvement

* one less canonicalization

* updates

* temporary tests

* Styling

* Update tests, avoid sanitizing when not necessary

* Simplify script

* Add test

* Additional fixes

* Use fixture instead of variable

* Remove test

* np.bool deprecation

* Reset README

* black

* fixes

* Add back tqdm

* Fix number batches

* remove print

* update version

* improve doc
avaucher authored Aug 4, 2022
1 parent 905d283 commit 5c8d015
Showing 7 changed files with 292 additions and 133 deletions.
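The first bullet in the commit message, "Better handling of compounds with the help of rxn-chemutils", is the headline change: the input reaction SMILES is now parsed with rxn-chemutils and the mapped reaction is serialized back in the same notation that was passed in. A minimal usage sketch (not part of this diff), assuming the public RXNMapper API and the example reaction from the project README:

    from rxnmapper import RXNMapper

    rxn_mapper = RXNMapper()
    # Plain, "~"-grouped and extended "|f:...|" reaction SMILES are accepted;
    # the mapped result comes back in the same format as the input.
    rxns = ["CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F"]
    results = rxn_mapper.get_attention_guided_atom_maps(rxns)
    print(results[0]["mapped_rxn"], results[0]["confidence"])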
4 changes: 2 additions & 2 deletions rxnmapper/attention.py
@@ -69,7 +69,7 @@ def __init__(
# Mask of atoms
self.atom_token_mask = np.array(
get_mask_for_tokens(self.tokens, self.special_tokens)
-        ).astype(np.bool)
+        ).astype(bool)

# Atoms numbered in the array
self.token2atom = np.array(number_tokens(tokens))
@@ -78,7 +78,7 @@
}

# Adjacency graph for all tokens
-        self.adjacency_matrix = tokens_to_adjacency(tokens).astype(np.bool)
+        self.adjacency_matrix = tokens_to_adjacency(tokens).astype(bool)

self._precursors_atom_types = None
self._product_atom_types = None
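The two attention.py hunks above replace np.bool with the built-in bool: the np.bool alias was deprecated in NumPy 1.20 and removed in NumPy 1.24, so casting with the built-in type is the forward-compatible spelling (this matches the "np.bool deprecation" bullet in the commit message). A minimal sketch, not part of this diff:

    import numpy as np

    mask = np.array([1, 0, 1]).astype(bool)  # works on old and new NumPy alike
    print(mask.dtype)                        # bool
    # np.array([1, 0, 1]).astype(np.bool)    # warns on NumPy 1.20+, fails on 1.24+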
76 changes: 66 additions & 10 deletions rxnmapper/core.py
@@ -3,11 +3,18 @@

import logging
import os
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
import pkg_resources
import torch
+from rxn.chemutils.reaction_equation import ReactionEquation
+from rxn.chemutils.reaction_smiles import (
+    ReactionFormat,
+    determine_format,
+    parse_any_reaction_smiles,
+    to_reaction_smiles,
+)
from transformers import AlbertModel, BertModel, RobertaModel

from .attention import AttentionScorer
@@ -16,7 +23,8 @@

MODEL_TYPE_DICT = {"bert": BertModel, "albert": AlbertModel, "roberta": RobertaModel}

-LOGGER = logging.getLogger("rxnmapper:core")
+_logger = logging.getLogger(__name__)
+_logger.addHandler(logging.NullHandler())


class RXNMapper:
@@ -28,7 +36,7 @@ class RXNMapper:

def __init__(
self,
-        config: Dict = {},
+        config: Optional[Dict[str, Any]] = None,
logger: Optional[logging.Logger] = None,
):
"""
@@ -45,6 +53,8 @@ def __init__(
>>> from rxnmapper import RXNMapper
>>> rxn_mapper = RXNMapper()
"""
+        if config is None:
+            config = {}

# Config takes "model_path", "model_type", "attention_multiplier", "head", "layers"
self.model_path = config.get(
@@ -58,7 +68,7 @@ def __init__(
self.head = config.get("head", 5)
self.layers = config.get("layers", [10])

-        self.logger = logger if logger else LOGGER
+        self.logger = logger if logger else _logger
self.model, self.tokenizer = self._load_model_and_tokenizer()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
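The __init__ changes above also replace the mutable default argument config: Dict = {} with a None sentinel plus an explicit guard. A default {} is created once and shared by every call, so mutating it in one place silently affects later calls; the None pattern avoids this. A minimal illustration of the pitfall, using a hypothetical configure helper rather than code from this repository:

    def configure(settings: dict = {}):        # one dict object shared by every call
        settings.setdefault("layers", [10])
        return settings

    first = configure()
    first["layers"] = [5]                      # mutates the shared default
    second = configure()
    print(second["layers"])                    # prints [5]: leaked from the first call

    def configure_safe(settings=None):
        if settings is None:                   # fresh dict per call, as in the new __init__
            settings = {}
        settings.setdefault("layers", [10])
        return settings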
@@ -178,10 +188,57 @@ def get_attention_guided_atom_maps(
- tokensxtokens_attns: Full attentions for all tokens
- tokens: Tokens that were inputted into the model
"""
+        reaction_formats = (determine_format(rxn) for rxn in rxns)
+        reactions = [parse_any_reaction_smiles(rxn) for rxn in rxns]
+
+        raw_results = self.get_attention_guided_atom_maps_for_reactions(
+            reactions=reactions,
+            zero_set_p=zero_set_p,
+            zero_set_r=zero_set_r,
+            canonicalize_rxns=canonicalize_rxns,
+            detailed_output=detailed_output,
+            absolute_product_inds=absolute_product_inds,
+            force_layer=force_layer,
+            force_head=force_head,
+        )
+
+        results = []
+        for (reaction, result), reaction_format in zip(raw_results, reaction_formats):
+            mapped_rxn = to_reaction_smiles(reaction, reaction_format=reaction_format)
+            result["mapped_rxn"] = mapped_rxn
+            results.append(result)
+        return results

+    def get_attention_guided_atom_maps_for_reactions(
+        self,
+        reactions: List[ReactionEquation],
+        zero_set_p: bool = True,
+        zero_set_r: bool = True,
+        canonicalize_rxns: bool = True,
+        detailed_output: bool = False,
+        absolute_product_inds: bool = False,
+        force_layer: Optional[int] = None,
+        force_head: Optional[int] = None,
+    ) -> Iterator[Tuple[ReactionEquation, Dict[str, Any]]]:
+        """Generate atom-mapping for ReactionEquation instances.
+        See documentation of get_attention_guided_atom_maps() for details on the
+        arguments and return value. The only difference is that the mapped reaction
+        is returned as a ReactionEquation, which is added to the dictionary
+        outside of this function depending on the required format.
+        """

if canonicalize_rxns:
-            rxns = [process_reaction(rxn) for rxn in rxns]
+            reactions = [process_reaction(reaction) for reaction in reactions]

+        # The transformer has been trained on the format containing tildes.
+        # This means that we must convert to that format for use with the model.
+        rxns = [
+            to_reaction_smiles(
+                reaction, reaction_format=ReactionFormat.STANDARD_WITH_TILDE
+            )
+            for reaction in reactions
+        ]

attns = self.convert_batch_to_attns(
rxns, force_layer=force_layer, force_head=force_head
@@ -204,10 +261,10 @@ def get_attention_guided_atom_maps(
absolute_product_inds=absolute_product_inds
)

+            mapped_reaction = generate_atom_mapped_reaction_atoms(
+                rxn, output["pxr_mapping_vector"], canonical=canonicalize_rxns
+            )
result = {
"mapped_rxn": generate_atom_mapped_reaction_atoms(
rxn, output["pxr_mapping_vector"]
),
"confidence": np.prod(output["confidences"]),
}
if detailed_output:
@@ -218,5 +275,4 @@
result["tokensxtokens_attns"] = tokensxtokens_attn
result["tokens"] = just_tokens

-            results.append(result)
-        return results
+            yield mapped_reaction, result
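The format round trip in core.py relies on three rxn-chemutils helpers imported at the top of the file diff: determine_format records the input notation, parse_any_reaction_smiles produces a ReactionEquation, and to_reaction_smiles serializes it again, either in the tilde format the model was trained on or back in the caller's original format. A minimal sketch of that round trip (not part of this diff; the example reaction is arbitrary):

    from rxn.chemutils.reaction_smiles import (
        ReactionFormat,
        determine_format,
        parse_any_reaction_smiles,
        to_reaction_smiles,
    )

    rxn = "CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-]~[K+]~[K+]>>CC(C)Sc1ncccc1F"
    fmt = determine_format(rxn)                # detected notation, kept for the output
    reaction = parse_any_reaction_smiles(rxn)  # format-agnostic ReactionEquation
    model_input = to_reaction_smiles(
        reaction, reaction_format=ReactionFormat.STANDARD_WITH_TILDE
    )                                          # what the transformer actually sees
    restored = to_reaction_smiles(reaction, reaction_format=fmt)  # caller's notation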
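Callers that already hold ReactionEquation objects can consume the new generator directly; it yields the mapped ReactionEquation together with the result dictionary, without the "mapped_rxn" string that only the SMILES-level wrapper adds. A minimal sketch, not part of this diff:

    from rxn.chemutils.reaction_smiles import (
        ReactionFormat,
        parse_any_reaction_smiles,
        to_reaction_smiles,
    )
    from rxnmapper import RXNMapper

    rxn_mapper = RXNMapper()
    equations = [parse_any_reaction_smiles("CC(C)S.Fc1cccnc1F>>CC(C)Sc1ncccc1F")]
    for mapped, info in rxn_mapper.get_attention_guided_atom_maps_for_reactions(equations):
        smiles = to_reaction_smiles(mapped, reaction_format=ReactionFormat.STANDARD)
        print(smiles, info["confidence"])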
