Skip to content

Commit

Permalink
removed try statement, catch invalid rxn str
Browse files Browse the repository at this point in the history
  • Loading branch information
dswigh committed Apr 14, 2023
1 parent f210b44 commit a5e7bcc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 49 deletions.
81 changes: 38 additions & 43 deletions orderly/extract/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,23 @@ def get_rxn_string_and_is_mapped(
rxn_str = rxn_str_extended_smiles.split(" ")[
0
] # this is to get rid of the extended smiles info
return RXN_STR(rxn_str), is_mapped

count = rxn_str.count(">")
if count == 2: # Finally, we need to check whether the reaction string is valid
return RXN_STR(rxn_str), is_mapped
else:
return None

@staticmethod
def extract_info_from_rxn(
rxn: ord_reaction_pb2.Reaction,
def extract_info_from_rxn_str(
rxn_str: RXN_STR, is_mapped: bool
) -> Optional[
Tuple[REACTANTS, AGENTS, PRODUCTS, RXN_STR, List[MOLECULE_IDENTIFIER]]
]:
"""
Input a reaction SMILES string (and whether it is mapped), and return the reactants, agents, products, the reaction SMILES string, and the list of non-SMILES names
"""
_ = rdkit_BlockLogs()
_rxn_str = OrdExtractor.get_rxn_string_and_is_mapped(rxn)
if _rxn_str is None:
return None
rxn_str, is_mapped = _rxn_str

reactant_from_rxn, agent, product_from_rxn = rxn_str.split(">")

Expand Down Expand Up @@ -585,7 +586,8 @@ def handle_reaction_object(
else:
rxn_str, is_mapped = _rxn_str

if trust_labelling:
# Get all the molecules
if trust_labelling or (rxn_str is None and use_labelling_if_extract_fails):
reactants = labelled_reactants
products = labelled_products
yields = yields
Expand All @@ -594,38 +596,29 @@ def handle_reaction_object(
reagents = labelled_reagents
catalysts = labelled_catalysts
is_mapped = False
elif (
(not trust_labelling)
and (rxn_str is None)
and (not use_labelling_if_extract_fails)
):
return None
else:
# TODO we should remove this try block
try: # to extract info from the reaction string
rxn_info = OrdExtractor.extract_info_from_rxn(rxn)
if rxn_info is None:
raise ValueError("rxn_info is None")
(
rxn_str_reactants,
rxn_str_agents,
rxn_str_products,
rxn_str,
rxn_non_smiles_names_list,
) = rxn_info
reactants = list(set(rxn_str_reactants))
# Resolve: yields are from rxn_outcomes, but we trust the products from the rxn_string
rxn_str_products = list(set(rxn_str_products))
products, _yields = OrdExtractor.match_yield_with_product(
rxn_str_products, labelled_products, yields
)
if _yields is None:
_yields = []
yields = _yields

except (ValueError, TypeError) as e:
rxn_str_agents = []
# ValueError is raised when rxn_info is None, or when it's invalid, e.g. if the rxn_string only has one >. Rxn strings should have 2, e.g. A>B>C
# TypeError is raised when rxn_str is neither None nor a string (this should be impossible though due to the schema!)
if use_labelling_if_extract_fails:
reactants = labelled_reactants
products = labelled_products
else:
return None
# extract info from the reaction string
rxn_info = OrdExtractor.extract_info_from_rxn_str(rxn_str, is_mapped)
(
reactants,
agents,
_products,
rxn_str,
rxn_non_smiles_names_list,
) = rxn_info
# Resolve: yields are from rxn_outcomes, but we trust the products from the rxn_string
products, _yields = OrdExtractor.match_yield_with_product(
_products, labelled_products, yields
)
if _yields is None:
_yields = []
yields = _yields

if (
include_unadded_labelled_agents
Expand All @@ -644,20 +637,22 @@ def handle_reaction_object(
molecules_unique_to_labelled_data = [
x
for x in all_labelled_molecules
if x not in reactants + rxn_str_agents + solvents + products
if x not in reactants + agents + solvents + products
]
rxn_str_agents += molecules_unique_to_labelled_data
agents += molecules_unique_to_labelled_data

if trust_labelling == False:
# Merge conditions
agents, solvents = OrdExtractor.merge_to_agents(
rxn_str_agents,
agents,
labelled_catalysts,
labelled_solvents,
labelled_reagents,
solvents_set,
)
reagents = []
catalysts = []

# extract temperature
temperature = OrdExtractor.temperature_extractor(rxn)

Expand Down Expand Up @@ -703,7 +698,7 @@ def canonicalise_and_get_non_smiles_names(
return mole_id_list, non_smiles_names_list_additions

# Reactants and products might be mapped, but agents are not
# TODO?: The canonicalisation is repeated! We extract information from rxn_str, and then apply logic to figure out what is a reactant/agent. So we canonicalise inside the extract_info_from_rxn function, but not within the input_extraction function, which is why we need to do it again here. This also means we add stuff to the non-smiles names list multiple times, so we need to do list(set()) on that list; all this is slightly inefficient, but shouldn't add that much overhead.
# TODO?: The canonicalisation is repeated! We extract information from rxn_str, and then apply logic to figure out what is a reactant/agent. So we canonicalise inside the extract_info_from_rxn_str function, but not within the input_extraction function, which is why we need to do it again here. This also means we add stuff to the non-smiles names list multiple times, so we need to do list(set()) on that list; all this is slightly inefficient, but shouldn't add that much overhead.
(
reactants,
non_smiles_names_list_additions,
Expand Down
16 changes: 10 additions & 6 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,13 @@ def test_rxn_string_and_is_mapped(
["CCC"],
"CC.C>CCC",
[],
False,
marks=pytest.mark.xfail(
reason="ValueError: not enough values to unpack (expected 3, got 2)"
),
True,
),
# There's no point in trying to test the case where rxn.identifiers[0].value = None, because the schema doesn't allow that overwrite to happen!
),
)
@pytest.mark.parametrize("execution_number", range(REPETITIONS))
def test_extract_info_from_rxn(
def test_extract_info_from_rxn_str(
execution_number: int,
file_name: str,
rxn_idx: int,
Expand All @@ -372,7 +369,14 @@ def test_extract_info_from_rxn(

import orderly.extract.extractor

rxn_info = orderly.extract.extractor.OrdExtractor.extract_info_from_rxn(rxn)
_rxn_info = orderly.extract.extractor.OrdExtractor.get_rxn_string_and_is_mapped(rxn)
if _rxn_info is None:
return None
rxn_str, is_mapped = _rxn_info

rxn_info = orderly.extract.extractor.OrdExtractor.extract_info_from_rxn_str(
rxn_str, is_mapped
)
if expected_none:
assert rxn_info is None, f"expected a none but got {rxn_info=}"
return
Expand Down

0 comments on commit a5e7bcc

Please sign in to comment.