diff --git a/chemcrow/tools/safety.py b/chemcrow/tools/safety.py index 87bd10d..fbe8cb7 100644 --- a/chemcrow/tools/safety.py +++ b/chemcrow/tools/safety.py @@ -11,73 +11,41 @@ from langchain import LLMChain, PromptTemplate from langchain.llms import BaseLLM from langchain.tools import BaseTool -from rdkit import Chem from chemcrow.utils import * -from chemcrow.utils import is_smiles, tanimoto +from chemcrow.utils import ( + is_multiple_smiles, + is_smiles, + query2cas, + query2smiles, + split_smiles, + tanimoto, +) from .prompts import safety_summary_prompt, summary_each_data -def query2smiles( - query: str, - url: str = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}", -) -> str: - if url is None: - url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}" - r = requests.get(url.format(query, "property/IsomericSMILES/JSON")) - # convert the response to a json object - data = r.json() - # return the SMILES string - try: - smi = data["PropertyTable"]["Properties"][0]["IsomericSMILES"] - except KeyError: - return "Could not find a molecule matching the text. One possible cause is that the input is incorrect, input one molecule at a time." - return str(Chem.CanonSmiles(largest_mol(smi))) - - -def query2cas(query: str, url_cid: str, url_data: str): - try: - mode = "name" - if is_smiles(query): - mode = "smiles" - url_cid = url_cid.format(mode, query) - cid = requests.get(url_cid).json()["IdentifierList"]["CID"][0] - url_data = url_data.format(cid) - data = requests.get(url_data).json() - except (requests.exceptions.RequestException, KeyError): - raise ValueError("Invalid molecule input, no Pubchem entry") - - try: - for section in data["Record"]["Section"]: - if section.get("TOCHeading") == "Names and Identifiers": - for subsection in section["Section"]: - if subsection.get("TOCHeading") == "Other Identifiers": - for subsubsection in subsection["Section"]: - if subsubsection.get("TOCHeading") == "CAS": - return subsubsection["Information"][0]["Value"][ - "StringWithMarkup" - ][0]["String"] - except KeyError: - raise ValueError("Invalid molecule input, no Pubchem entry") - - raise ValueError("CAS number not found") - - class PatentCheck(BaseTool): name = "PatentCheck" - description = "Input SMILES, returns if molecule is patented" + description = "Input SMILES, returns if molecule is patented. You may also input several SMILES, separated by a period." def _run(self, smiles: str) -> str: """Checks if compound is patented. Give this tool only one SMILES string""" + if is_multiple_smiles(smiles): + smiles_list = split_smiles(smiles) + else: + smiles_list = [smiles] try: - r = molbloom.buy(smiles, canonicalize=True, catalog="surechembl") + output_dict = {} + for smi in smiles_list: + r = molbloom.buy(smi, canonicalize=True, catalog="surechembl") + if r: + output_dict[smi] = "Patented" + else: + output_dict[smi] = "Novel" + return str(output_dict) except: return "Invalid SMILES string" - if r: - return "Patented" - else: - return "Novel" async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" @@ -359,7 +327,10 @@ def _run(self, query: str) -> str: ) else: # Get smiles of CAS number - smi = query2smiles(query) + try: + smi = query2smiles(query) + except ValueError as e: + return str(e) # Check similarity to known controlled chemicals return self.similar_control_chem_check._run(smi) @@ -386,7 +357,10 @@ def __init__( def _run(self, query: str) -> str: """This function queries the given molecule name and returns a SMILES string from the record""" """Useful to get the SMILES string of one molecule by searching the name of a molecule. Only query with one specific name.""" - smi = query2smiles(query, self.url) + try: + smi = query2smiles(query, self.url) + except ValueError as e: + return str(e) # check if smiles is controlled msg = "Note: " + self.ControlChemCheck._run(smi) if "high similarity" in msg or "appears" in msg: @@ -422,9 +396,15 @@ def _run(self, query: str) -> str: smiles = None if is_smiles(query): smiles = query - cas = query2cas(query, self.url_cid, self.url_data) + try: + cas = query2cas(query, self.url_cid, self.url_data) + except ValueError as e: + return str(e) if smiles is None: - smiles = query2smiles(query, None) + try: + smiles = query2smiles(cas, None) + except ValueError as e: + return str(e) # great now check if smiles is controlled msg = self.ControlChemCheck._run(smiles) if "high similarity" in msg or "appears" in msg: diff --git a/chemcrow/utils.py b/chemcrow/utils.py index c25a1f2..cf2310f 100644 --- a/chemcrow/utils.py +++ b/chemcrow/utils.py @@ -1,5 +1,6 @@ import re +import requests from rdkit import Chem, DataStructs from rdkit.Chem import AllChem @@ -14,6 +15,16 @@ def is_smiles(text): return False +def is_multiple_smiles(text): + if is_smiles(text): + return "." in text + return False + + +def split_smiles(text): + return text.split(".") + + def is_cas(text): pattern = r"^\d{2,7}-\d{2}-\d$" return re.match(pattern, text) is not None @@ -46,3 +57,59 @@ def tanimoto(s1, s2): return DataStructs.TanimotoSimilarity(fp1, fp2) except (TypeError, ValueError, AttributeError): return "Error: Not a valid SMILES string" + + +def query2smiles( + query: str, + url: str = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}", +) -> str: + if is_smiles(query): + if not is_multiple_smiles(query): + return query + else: + raise ValueError( + "Multiple SMILES strings detected, input one molecule at a time." + ) + if url is None: + url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}" + r = requests.get(url.format(query, "property/IsomericSMILES/JSON")) + # convert the response to a json object + data = r.json() + # return the SMILES string + try: + smi = data["PropertyTable"]["Properties"][0]["IsomericSMILES"] + except KeyError: + return "Could not find a molecule matching the text. One possible cause is that the input is incorrect, input one molecule at a time." + return str(Chem.CanonSmiles(largest_mol(smi))) + + +def query2cas(query: str, url_cid: str, url_data: str): + try: + mode = "name" + if is_smiles(query): + if is_multiple_smiles(query): + raise ValueError( + "Multiple SMILES strings detected, input one molecule at a time." + ) + mode = "smiles" + url_cid = url_cid.format(mode, query) + cid = requests.get(url_cid).json()["IdentifierList"]["CID"][0] + url_data = url_data.format(cid) + data = requests.get(url_data).json() + except (requests.exceptions.RequestException, KeyError): + raise ValueError("Invalid molecule input, no Pubchem entry") + + try: + for section in data["Record"]["Section"]: + if section.get("TOCHeading") == "Names and Identifiers": + for subsection in section["Section"]: + if subsection.get("TOCHeading") == "Other Identifiers": + for subsubsection in subsection["Section"]: + if subsubsection.get("TOCHeading") == "CAS": + return subsubsection["Information"][0]["Value"][ + "StringWithMarkup" + ][0]["String"] + except KeyError: + raise ValueError("Invalid molecule input, no Pubchem entry") + + raise ValueError("CAS number not found") diff --git a/tests/test_databases.py b/tests/test_databases.py index 80f0815..2eba604 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,7 +1,9 @@ +import ast + import pytest from chemcrow.tools.safety import PatentCheck, Query2CAS, Query2SMILES -from chemcrow.utils import canonical_smiles +from chemcrow.utils import canonical_smiles, split_smiles @pytest.fixture @@ -40,9 +42,6 @@ def choline(): return "CCCCCCCCC[NH+]1C[C@@H]([C@H]([C@@H]([C@H]1CO)O)O)O" -# Query2SMILES - - def test_q2s_iupac(single_iupac): tool = Query2SMILES() out = tool._run(single_iupac) @@ -60,8 +59,6 @@ def test_q2s_fail(molset1): out = tool._run(molset1) assert out.endswith("input one molecule at a time.") -# Query2CAS - def test_q2cas_iupac(single_iupac): tool = Query2CAS() @@ -81,13 +78,22 @@ def test_q2cas_badinp(): assert out.endswith("no Pubchem entry") or out.endswith("not found") -# PatentCheck - - def test_patentcheck(singlemol): tool = PatentCheck() patented = tool._run(singlemol) - assert patented == "Patented" + patented = ast.literal_eval(patented) + assert len(patented) == 1 + assert patented[singlemol] == "Patented" + + +def test_patentcheck_molset(molset1): + tool = PatentCheck() + patented = tool._run(molset1) + patented = ast.literal_eval(patented) + mols = split_smiles(molset1) + assert len(patented) == len(mols) + assert patented[mols[0]] == "Patented" + assert patented[mols[1]] == "Novel" def test_patentcheck_iupac(single_iupac): @@ -99,4 +105,6 @@ def test_patentcheck_iupac(single_iupac): def test_patentcheck_not(choline): tool = PatentCheck() patented = tool._run(choline) - assert patented == "Novel" + patented = ast.literal_eval(patented) + assert len(patented) == 1 + assert patented[choline] == "Novel"