From 59e63bdcd879fdf180e7b70705f9e2d64b03c15c Mon Sep 17 00:00:00 2001
From: Andres
Date: Thu, 14 Mar 2024 13:29:41 +0100
Subject: [PATCH] adding semanticscholar_apikey to prevent exceeding rate limit

---
 chemcrow/agents/tools.py                    |  9 ++++-
 .../frontend/streamlit_callback_handler.py  |  4 +-
 chemcrow/tools/search.py                    | 39 +++++++++++++------
 3 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/chemcrow/agents/tools.py b/chemcrow/agents/tools.py
index d607a0a..6574cd8 100644
--- a/chemcrow/agents/tools.py
+++ b/chemcrow/agents/tools.py
@@ -13,6 +13,9 @@ def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, verbose=True):
     chemspace_api_key = api_keys.get("CHEMSPACE_API_KEY") or os.getenv(
         "CHEMSPACE_API_KEY"
     )
+    semantic_scholar_api_key = api_keys.get("SEMANTIC_SCHOLAR_API_KEY") or os.getenv(
+        "SEMANTIC_SCHOLAR_API_KEY"
+    )
 
     all_tools = agents.load_tools(
         [
@@ -34,8 +37,12 @@
         ExplosiveCheck(),
         ControlChemCheck(),
         SimilarControlChemCheck(),
-        Scholar2ResultLLM(llm=llm, api_key=openai_api_key),
         SafetySummary(llm=llm),
+        Scholar2ResultLLM(
+            llm=llm,
+            openai_api_key=openai_api_key,
+            semantic_scholar_api_key=semantic_scholar_api_key
+        ),
     ]
     if chemspace_api_key:
         all_tools += [GetMoleculePrice(chemspace_api_key)]
diff --git a/chemcrow/frontend/streamlit_callback_handler.py b/chemcrow/frontend/streamlit_callback_handler.py
index a3060c0..c05bc15 100644
--- a/chemcrow/frontend/streamlit_callback_handler.py
+++ b/chemcrow/frontend/streamlit_callback_handler.py
@@ -79,9 +79,9 @@ def on_tool_start(
         )
 
         # Display note of potential long time
-        if serialized["name"] == "ReactionRetrosynthesis":
+        if serialized["name"] == "ReactionRetrosynthesis" or serialized["name"] == "LiteratureSearch":
             self._container.markdown(
-                f"‼️ Note: This tool can take up to 5 minutes to complete execution ‼️",
+                f"‼️ Note: This tool can take some time to complete execution ‼️",
                 unsafe_allow_html=True,
             )
 
diff --git a/chemcrow/tools/search.py b/chemcrow/tools/search.py
index 848bfde..527600f 100644
--- a/chemcrow/tools/search.py
+++ b/chemcrow/tools/search.py
@@ -14,14 +14,18 @@
 from chemcrow.utils import is_multiple_smiles, split_smiles
 
 
-def paper_scraper(search: str, pdir: str = "query") -> dict:
+def paper_scraper(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
     try:
-        return paperscraper.search_papers(search, pdir=pdir)
+        return paperscraper.search_papers(
+            search,
+            pdir=pdir,
+            semantic_scholar_api_key=semantic_scholar_api_key,
+        )
     except KeyError:
         return {}
 
 
-def paper_search(llm, query):
+def paper_search(llm, query, semantic_scholar_api_key=None):
     prompt = langchain.prompts.PromptTemplate(
         input_variables=["question"],
         template="""
@@ -37,14 +41,14 @@ def paper_search(llm, query):
         os.mkdir("query/")
     search = query_chain.run(query)
     print("\nSearch:", search)
-    papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}")
+    papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}", semantic_scholar_api_key=semantic_scholar_api_key)
     return papers
 
 
-def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None):
+def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
     """Useful to answer questions that require technical knowledge.
     Ask a specific question."""
-    papers = paper_search(llm, query)
+    papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
     if len(papers) == 0:
         return "Not enough papers found"
     docs = paperqa.Docs(
@@ -59,7 +63,11 @@ def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None):
         except (ValueError, FileNotFoundError, PdfReadError):
             not_loaded += 1
 
-    print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
+    if not_loaded > 0:
+        print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}.")
+    else:
+        print(f"\nFound {len(papers.items())} papers and loaded all of them.")
+
     answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
     return answer
 
@@ -71,15 +79,24 @@ class Scholar2ResultLLM(BaseTool):
         "knowledge. Ask a specific question."
     )
     llm: BaseLanguageModel = None
-    api_key: str = None
+    openai_api_key: str = None
+    semantic_scholar_api_key: str = None
+
 
-    def __init__(self, llm, api_key):
+    def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
         super().__init__()
         self.llm = llm
-        self.api_key = api_key
+        # api keys
+        self.openai_api_key = openai_api_key
+        self.semantic_scholar_api_key = semantic_scholar_api_key
 
     def _run(self, query) -> str:
-        return scholar2result_llm(self.llm, query, openai_api_key=self.api_key)
+        return scholar2result_llm(
+            self.llm,
+            query,
+            openai_api_key=self.openai_api_key,
+            semantic_scholar_api_key=self.semantic_scholar_api_key
+        )
 
     async def _arun(self, query) -> str:
         """Use the tool asynchronously."""
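
Usage sketch (not part of the patch itself): a minimal example of the interface changed above. It assumes ChatOpenAI as the backing LLM and that both keys are exported as environment variables; the import paths simply mirror the file layout shown in the diff, and the query string is illustrative.

    import os

    from langchain.chat_models import ChatOpenAI

    from chemcrow.agents.tools import make_tools
    from chemcrow.tools.search import Scholar2ResultLLM

    llm = ChatOpenAI(temperature=0.1)

    # Option 1: pass the key through api_keys; make_tools also falls back to
    # the SEMANTIC_SCHOLAR_API_KEY environment variable via os.getenv.
    tools = make_tools(
        llm,
        api_keys={"SEMANTIC_SCHOLAR_API_KEY": os.environ["SEMANTIC_SCHOLAR_API_KEY"]},
    )

    # Option 2: build the literature-search tool directly with both keys.
    lit_search = Scholar2ResultLLM(
        llm=llm,
        openai_api_key=os.environ["OPENAI_API_KEY"],
        semantic_scholar_api_key=os.environ["SEMANTIC_SCHOLAR_API_KEY"],
    )
    print(lit_search.run("What are common ligands for Suzuki coupling catalysts?"))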