diff --git a/chemcrow/tools/rxn4chem.py b/chemcrow/tools/rxn4chem.py
index 7927684..6e5bad1 100644
--- a/chemcrow/tools/rxn4chem.py
+++ b/chemcrow/tools/rxn4chem.py
@@ -53,3 +53,6 @@ def _run(self, reactants: str) -> str:
         product = res_dict["productMolecule"]["smiles"]
 
         return product
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")
diff --git a/chemcrow/tools/safety.py b/chemcrow/tools/safety.py
index 27bad77..a772346 100644
--- a/chemcrow/tools/safety.py
+++ b/chemcrow/tools/safety.py
@@ -160,6 +160,9 @@ def _run(self, cas: str) -> str:
         data = self.mol_safety.get_safety_summary(cas)
 
         return self.llm_chain.run(" ".join(data))
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")
 
 
 
@@ -185,4 +188,7 @@ def _run(self, cas_number):
             return "Molecule is explosive"
         else:
             return "Molecule is not known to be explosive."
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")
 
diff --git a/chemcrow/tools/search.py b/chemcrow/tools/search.py
index 022d0c1..50dc758 100644
--- a/chemcrow/tools/search.py
+++ b/chemcrow/tools/search.py
@@ -1,95 +1,91 @@
 import os
 import re
-from typing import Optional
-
-import langchain
 import paperqa
+import langchain
 import paperscraper
 from langchain import SerpAPIWrapper
-from langchain.base_language import BaseLanguageModel
-from langchain.chains import LLMChain
-from langchain.tools import BaseTool
-from pydantic import validator
 from pypdf.errors import PdfReadError
+from langchain.tools import BaseTool
+from langchain.base_language import BaseLanguageModel
 
-class LitSearch(BaseTool):
-    name = "LiteratureSearch"
-    description = (
-        "Input a specific question, returns an answer from literature search. "
-        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+def paper_search(search, pdir="query"):
+    try:
+        return paperscraper.search_papers(search, pdir=pdir)
+    except KeyError:
+        return {}
+
+def partial(func, *args, **kwargs):
+    """
+    This function is a workaround for the partial function error in new langchain versions.
+    This can be removed if langchain adds support for partial functions.
+    """
+    def wrapped(*args_wrapped, **kwargs_wrapped):
+        final_args = args + args_wrapped
+        final_kwargs = {**kwargs, **kwargs_wrapped}
+        return func(*final_args, **final_kwargs)
+    return wrapped
+
+def scholar2result_llm(llm, query, search=None):
+    """Useful to answer questions that require technical knowledge. Ask a specific question."""
+
+    prompt = langchain.prompts.PromptTemplate(
+        input_variables=["question"],
+        template="I would like to find scholarly papers to answer this question: {question}. "
+        'A search query that would bring up papers that can answer this question would be: "',
     )
-    llm: BaseLanguageModel
-    query_chain: Optional[LLMChain] = None
-    pdir: str = "query"
-    searches: int = 2
-    verbose: bool = False
-    docs: Optional[paperqa.Docs] = None
+    query_chain = langchain.chains.LLMChain(llm=llm, prompt=prompt)
 
-    @validator("query_chain", always=True)
-    def init_query_chain(cls, v, values):
-        if v is None:
-            search_prompt = langchain.prompts.PromptTemplate(
-                input_variables=["question", "count"],
-                template="We want to answer the following question: {question} \n"
-                "Provide {count} keyword searches (one search per line) "
-                "that will find papers to help answer the question. "
-                "Do not use boolean operators. "
-                "Make some searches broad and some narrow. "
-                "Do not use boolean operators or quotes.\n\n"
-                "1. ",
-            )
-            v = LLMChain(llm=values["llm"], prompt=search_prompt)
-        return v
+    if not os.path.isdir("./query"):
+        os.mkdir("query/")
 
-    @validator("pdir", always=True)
-    def init_pdir(cls, v):
-        if not os.path.isdir(v):
-            os.mkdir(v)
-        return v
+    if search is None:
+        search = query_chain.run(query)
+    print("\nSearch:", search)
+    papers = paper_search(search, pdir=f"query/{re.sub(' ', '', search)}")
 
-    def paper_search(self, search):
+    if len(papers) == 0:
+        return "Not enough papers found"
+    docs = paperqa.Docs(llm=llm)
+    not_loaded = 0
+    for path, data in papers.items():
         try:
-            return paperscraper.search_papers(
-                search, pdir=self.pdir, batch_size=6, limit=4, verbose=False
-            )
-        except KeyError:
-            return {}
+            docs.add(path, data["citation"])
+        except (ValueError, FileNotFoundError, PdfReadError) as e:
+            not_loaded += 1
 
-    def _run(self, query: str) -> str:
-        if self.verbose:
-            print("\n\nChoosing search terms\n1. ", end="")
-        searches = self.query_chain.run(question=query, count=self.searches)
-        print("")
-        queries = [s for s in searches.split("\n") if len(s) > 3]
-        # remove 2., 3. from queries
-        queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
-        # remove quotes
-        queries = [re.sub(r"\"", "", q) for q in queries]
-        papers = {}
-        for q in queries:
-            papers.update(self.paper_search(q))
-        if self.verbose:
-            print(f"retrieved {len(papers)} papers total")
+    print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
+    return docs.query(query, length_prompt="about 100 words").answer
 
-        if len(papers) == 0:
-            return "Not enough papers found"
-        if self.docs is None:
-            self.docs = paperqa.Docs(
-                llm=self.llm, summary_llm="gpt-3.5-turbo", memory=True
-            )
-        not_loaded = 0
-        for path, data in papers.items():
-            try:
-                self.docs.add(path, citation=data["citation"], docname=data["key"])
-            except (ValueError, PdfReadError):
-                not_loaded += 1
-        if not_loaded:
-            print(f"\nFound {len(papers.items())} papers, couldn't load {not_loaded}")
-        return self.docs.query(query, length_prompt="about 100 words").answer
+def web_search(keywords, search_engine="google"):
+    try:
+        return SerpAPIWrapper(
+            serpapi_api_key=os.getenv("SERP_API_KEY"), search_engine=search_engine
+        ).run(keywords)
+    except:
+        return "No results, try another search"
+
+class LitSearch(BaseTool):
+    name = "LiteratureSearch"
+    description = (
+        "Input a specific question, returns an answer from literature search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    llm: BaseLanguageModel
+    def _run(self, query: str) -> str:
+        return scholar2result_llm(self.llm, query)
 
     async def _arun(self, query: str) -> str:
-        """Use the tool asynchronously."""
-        raise NotImplementedError()
-
+        raise NotImplementedError("Async not implemented")
+
+class WebSearch(BaseTool):
+    name = "WebSearch"
+    description = (
+        "Input a specific question, returns an answer from web search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    def _run(self, query: str) -> str:
+        return web_search(query)
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("Async not implemented")
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0b9d80b..73f5fb1 100644
--- a/setup.py
+++ b/setup.py
@@ -20,10 +20,11 @@
         "ipython",
         "rdkit",
         "synspace",
+        "openai==0.27.8",
         "molbloom",
-        "paper-qa>=3.0.0",
+        "paper-qa==1.1.1",
         "google-search-results",
-        "langchain",
+        "langchain==0.0.234",
         "nest_asyncio",
         "tiktoken",
         "rmrkl",
diff --git a/tests/test_search.py b/tests/test_search.py
index f6c6df7..6fcafc2 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -15,12 +15,12 @@ def questions():
     return qs
 
 
-#def test_litsearch(questions):
-#    llm = ChatOpenAI()
-#    searchtool = LitSearch(llm=llm)
-#
-#    for q in questions:
-#        ans = searchtool(q)
-#        assert isinstance(ans, str)
-#        assert len(ans) > 0
+def test_litsearch(questions):
+    llm = ChatOpenAI()
+    searchtool = LitSearch(llm=llm)
+
+    for q in questions:
+        ans = searchtool(q)
+        assert isinstance(ans, str)
+        assert len(ans) > 0