Skip to content

Commit

Permalink
Add mypy checks and typing support (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher authored Oct 11, 2022
1 parent c5e8da8 commit 62210e5
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ jobs:
run: python -m isort --check --diff .
- name: Check flake8
run: python -m flake8 .
- name: Run mypy
run: python -m mypy .
- name: Run pytests
run: python -m pytest
3 changes: 2 additions & 1 deletion docs_source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#
import os
import sys
from typing import List

sys.path.insert(0, os.path.abspath("."))

Expand Down Expand Up @@ -43,7 +44,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[str] = []

# -- Options for HTML output -------------------------------------------------

Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,21 @@
requires = ["setuptools >= 59.2.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.mypy]
check_untyped_defs = true
plugins = [
"numpy.typing.mypy_plugin",
]

[[tool.mypy.overrides]]
module = [
"pandas.*",
"rdkit.*",
"scipy.*",
"tqdm.*",
"transformers.*",
]
ignore_missing_imports = true

[tool.isort]
profile = "black"
92 changes: 46 additions & 46 deletions rxnmapper/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
domain / in the token domain, and accounting for adjacent atoms in molecules.
"""
import logging
from typing import List
from typing import List, Optional

import numpy as np

Expand Down Expand Up @@ -80,39 +80,39 @@ def __init__(
# Adjacency graph for all tokens
self.adjacency_matrix = tokens_to_adjacency(tokens).astype(bool)

self._precursors_atom_types = None
self._product_atom_types = None
self._rnums_atoms = None
self._pnums_atoms = None
self._nreactant_atoms = None
self._nproduct_atoms = None
self._adjacency_matrix_products = None
self._adjacency_matrix_precursors = None
self._pxr_filt_atoms = None
self._rxp_filt_atoms = None
self._atom_type_mask = None
self._atom_type_masked_attentions = None
self._precursors_atom_types: Optional[List[int]] = None
self._product_atom_types: Optional[List[int]] = None
self._rnums_atoms: Optional[np.ndarray] = None
self._pnums_atoms: Optional[np.ndarray] = None
self._nreactant_atoms: Optional[int] = None
self._nproduct_atoms: Optional[int] = None
self._adjacency_matrix_products: Optional[np.ndarray] = None
self._adjacency_matrix_precursors: Optional[np.ndarray] = None
self._pxr_filt_atoms: Optional[np.ndarray] = None
self._rxp_filt_atoms: Optional[np.ndarray] = None
self._atom_type_mask: Optional[np.ndarray] = None
self._atom_type_masked_attentions: Optional[np.ndarray] = None

# Attention multiplication matrix
self.attention_multiplier_matrix = np.ones_like(
self.combined_attentions_filt_atoms
).astype(float)

@property
def atom_attentions(self):
def atom_attentions(self) -> np.ndarray:
"""The MxM attention matrix, selected for only attentions that are from atoms, to atoms"""
return self.attentions[self.atom_token_mask].T[self.atom_token_mask].T

@property
def adjacent_atom_attentions(self):
def adjacent_atom_attentions(self) -> np.ndarray:
"""The MxM attention matrix, where all attentions are zeroed if the attention is not to an adjacent atom."""
atts = self.atom_attentions.copy()
mask = np.logical_not(self.adjacency_matrix)
atts[mask] = 0
return atts

@property
def adjacency_matrix_reactants(self):
def adjacency_matrix_reactants(self) -> np.ndarray:
"""The adjacency matrix of the reactants"""
if self._adjacency_matrix_precursors is None:
self._adjacency_matrix_precursors = self.adjacency_matrix[
Expand All @@ -121,7 +121,7 @@ def adjacency_matrix_reactants(self):
return self._adjacency_matrix_precursors

@property
def adjacency_matrix_products(self):
def adjacency_matrix_products(self) -> np.ndarray:
"""The adjacency matrix of the products"""
if self._adjacency_matrix_products is None:
self._adjacency_matrix_products = self.adjacency_matrix[
Expand All @@ -130,7 +130,7 @@ def adjacency_matrix_products(self):
return self._adjacency_matrix_products

@property
def atom_type_masked_attentions(self):
def atom_type_masked_attentions(self) -> np.ndarray:
"""Generate a"""
if self._atom_type_masked_attentions is None:
self._atom_type_masked_attentions = np.multiply(
Expand All @@ -139,17 +139,17 @@ def atom_type_masked_attentions(self):
return self._atom_type_masked_attentions

@property
def rxp(self):
def rxp(self) -> np.ndarray:
"""Subset of attentions relating the reactants to the products"""
return self.attentions[: self.split_ind, (self.split_ind + 1) :]

@property
def rxp_filt(self):
def rxp_filt(self) -> np.ndarray:
"""RXP without the special tokens"""
return self.rxp[1:, :-1]

@property
def rxp_filt_atoms(self):
def rxp_filt_atoms(self) -> np.ndarray:
"""RXP only the atoms, no special tokens"""
if self._rxp_filt_atoms is None:
self._rxp_filt_atoms = self.rxp[[i != -1 for i in self.rnums]][
Expand All @@ -158,18 +158,18 @@ def rxp_filt_atoms(self):
return self._rxp_filt_atoms

@property
def pxr(self):
def pxr(self) -> np.ndarray:
"""Subset of attentions relating the products to the reactants"""
i = self.split_ind
return self.attentions[(i + 1) :, :i]

@property
def pxr_filt(self):
def pxr_filt(self) -> np.ndarray:
"""PXR without the special tokens"""
return self.pxr[:-1, 1:]

@property
def pxr_filt_atoms(self):
def pxr_filt_atoms(self) -> np.ndarray:
"""PXR only the atoms, no special tokens"""
if self._pxr_filt_atoms is None:
self._pxr_filt_atoms = self.pxr[[i != -1 for i in self.pnums]][
Expand All @@ -178,22 +178,22 @@ def pxr_filt_atoms(self):
return self._pxr_filt_atoms

@property
def combined_attentions(self):
def combined_attentions(self) -> np.ndarray:
"""Summed pxr and rxp"""
return self.pxr + self.rxp.T

@property
def combined_attentions_filt(self):
def combined_attentions_filt(self) -> np.ndarray:
"""Summed pxr_filt and rxp_filt (no special tokens)"""
return self.pxr_filt + self.rxp_filt.T

@property
def combined_attentions_filt_atoms(self):
def combined_attentions_filt_atoms(self) -> np.ndarray:
"""Summed pxr_filt_atoms and rxp_filt_atoms (no special tokens, no "non-atom" tokens)"""
return self.pxr_filt_atoms + self.rxp_filt_atoms.T

@property
def combined_attentions_filt_atoms_same_type(self):
def combined_attentions_filt_atoms_same_type(self) -> np.ndarray:
"""Summed pxr_filt_atoms and rxp_filt_atoms (no special tokens, no "non-atom" tokens). All attentions to atoms of a different type are zeroed"""

atom_type_mask = np.zeros(self.combined_attentions_filt_atoms.shape)
Expand All @@ -218,23 +218,23 @@ def combined_attentions_filt_atoms_same_type(self):
return normalized_attentions

@property
def pnums(self):
def pnums(self) -> np.ndarray:
"""Get atom indexes for just the product tokens.
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
"""
return self.token2atom[(self.split_ind + 1) :]

@property
def pnums_filt(self):
def pnums_filt(self) -> np.ndarray:
"""Get atom indexes for just the product tokens, without the [SEP].
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
"""
return self.pnums

@property
def pnums_atoms(self):
def pnums_atoms(self) -> np.ndarray:
"""Get atom indexes for just the product ATOMS, without the [SEP].
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
Expand All @@ -244,23 +244,23 @@ def pnums_atoms(self):
return self._pnums_atoms

@property
def rnums(self):
def rnums(self) -> np.ndarray:
"""Get atom indexes for the reactant tokens.
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
"""
return self.token2atom[: self.split_ind]

@property
def rnums_filt(self):
def rnums_filt(self) -> np.ndarray:
"""Get atom indexes for just the reactant tokens, without the [CLS].
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
"""
return self.rnums[1:]

@property
def rnums_atoms(self):
def rnums_atoms(self) -> np.ndarray:
"""Get atom indexes for the reactant ATOMS, without the [CLS].
Numbers in this vector that are >= 0 are atoms, whereas indexes == -1 represent special tokens (e.g., bonds, parens, [CLS])
Expand All @@ -270,38 +270,38 @@ def rnums_atoms(self):
return self._rnums_atoms

@property
def nreactant_atoms(self):
def nreactant_atoms(self) -> int:
"""The number of atoms in the reactants"""
if self._nreactant_atoms is None:
self._nreactant_atoms = len(self.rnums_atoms)

return self._nreactant_atoms

@property
def nproduct_atoms(self):
def nproduct_atoms(self) -> int:
"""The number of atoms in the product"""
if self._nproduct_atoms is None:
self._nproduct_atoms = len(self.pnums_atoms)

return self._nproduct_atoms

@property
def rtokens(self):
def rtokens(self) -> List[str]:
"""Just the reactant tokens"""
return self.tokens[self._reactant_inds]

@property
def rtokens_filt(self):
def rtokens_filt(self) -> List[str]:
"""Reactant tokens without special tokens"""
return self.rtokens[1:]

@property
def ptokens(self):
def ptokens(self) -> List[str]:
"""Just the product tokens"""
return self.tokens[self._product_inds]

@property
def ptokens_filt(self):
def ptokens_filt(self) -> List[str]:
"""Product tokens without special tokens"""
return self.ptokens[:-1]

Expand Down Expand Up @@ -331,21 +331,21 @@ def get_neighboring_atoms(self, atom_num):
"""Get the atom indexes neighboring the desired atom"""
return np.nonzero(self.adjacency_matrix[atom_num])[0]

def get_precursors_atom_types(self):
def get_precursors_atom_types(self) -> List[int]:
"""Convert reactants into their atomic numbers"""
if self._precursors_atom_types is None:
self._precursors_atom_types = get_atom_types_smiles(
"".join(self.rtokens[1:])
)
return self._precursors_atom_types

def get_product_atom_types(self):
def get_product_atom_types(self) -> List[int]:
"""Convert products into their atomic indexes"""
if self._product_atom_types is None:
self._product_atom_types = get_atom_types_smiles("".join(self.ptokens[:-1]))
return self._product_atom_types

def get_atom_type_mask(self):
def get_atom_type_mask(self) -> np.ndarray:
"""Return a mask where only atoms of the same type are True"""
if self._atom_type_mask is None:
atom_type_mask = np.zeros(self.combined_attentions_filt_atoms.shape)
Expand Down Expand Up @@ -430,7 +430,7 @@ def generate_attention_guided_pxr_atom_mapping(
return output

def _update_attention_multiplier_matrix(
self, product_atom: int, reactant_atom: int
self, product_atom: np.signedinteger, reactant_atom: int
):
"""Perform the "neighbor multiplier" step of the atom mapping
Expand Down Expand Up @@ -458,9 +458,9 @@ def _update_attention_multiplier_matrix(
len(self.pnums_atoms)
)

def __len__(self):
def __len__(self) -> int:
"""Length of provided tokens"""
return len(self.tokens)

def __repr__(self):
def __repr__(self) -> str:
return f"AttMapper(`{self.rxn[:50]}...`)"
9 changes: 3 additions & 6 deletions rxnmapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def convert_batch_to_attns(
rxn_smiles_list: List[str],
force_layer: Optional[int] = None,
force_head: Optional[int] = None,
):
) -> List[torch.Tensor]:
"""Extract desired attentions from a given batch of reactions.
Args:
Expand Down Expand Up @@ -139,10 +139,7 @@ def convert_batch_to_attns(
selected_attns = torch.mean(selected_attns, dim=[1])
att_masks = encoded_ids["attention_mask"].to(torch.bool)

selected_attns = [
a[mask][:, mask] for a, mask in zip(selected_attns, att_masks)
]
return selected_attns
return [a[mask][:, mask] for a, mask in zip(selected_attns, att_masks)]

def tokenize_for_model(self, rxn: str):
"""Tokenize a reaction SMILES with the special tokens needed for the model"""
Expand Down Expand Up @@ -261,7 +258,7 @@ def get_attention_guided_atom_maps_for_reactions(
absolute_product_inds=absolute_product_inds
)

mapped_reaction = generate_atom_mapped_reaction_atoms(
mapped_reaction, _ = generate_atom_mapped_reaction_atoms(
rxn, output["pxr_mapping_vector"], canonical=canonicalize_rxns
)
result = {
Expand Down
Empty file added rxnmapper/py.typed
Empty file.
Loading

0 comments on commit 62210e5

Please sign in to comment.