
Commit

adding semanticscholar_apikey to prevent exceeding rate limit
doncamilom committed Mar 14, 2024
1 parent 4a0b137 commit 59e63bd
Showing 3 changed files with 38 additions and 14 deletions.
9 changes: 8 additions & 1 deletion chemcrow/agents/tools.py
@@ -13,6 +13,9 @@ def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, verbose=True):
     chemspace_api_key = api_keys.get("CHEMSPACE_API_KEY") or os.getenv(
         "CHEMSPACE_API_KEY"
     )
+    semantic_scholar_api_key = api_keys.get("SEMANTIC_SCHOLAR_API_KEY") or os.getenv(
+        "SEMANTIC_SCHOLAR_API_KEY"
+    )
 
     all_tools = agents.load_tools(
         [
@@ -34,8 +37,12 @@ def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, verbose=True):
         ExplosiveCheck(),
         ControlChemCheck(),
         SimilarControlChemCheck(),
-        Scholar2ResultLLM(llm=llm, api_key=openai_api_key),
         SafetySummary(llm=llm),
+        Scholar2ResultLLM(
+            llm=llm,
+            openai_api_key=openai_api_key,
+            semantic_scholar_api_key=semantic_scholar_api_key
+        ),
     ]
     if chemspace_api_key:
         all_tools += [GetMoleculePrice(chemspace_api_key)]
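For context on how the new key reaches make_tools: it can be supplied through the api_keys dict (which is checked first) or read from the environment. A minimal usage sketch; the key values are placeholders and the ChatOpenAI wrapper is an assumption, not part of this commit:

import os

from langchain.chat_models import ChatOpenAI  # assumed model wrapper; any BaseLanguageModel works

from chemcrow.agents.tools import make_tools

# Option 1: export the key so os.getenv picks it up inside make_tools.
os.environ["SEMANTIC_SCHOLAR_API_KEY"] = "<your-semantic-scholar-key>"  # placeholder

# Option 2: pass it explicitly; api_keys entries take precedence over the environment.
llm = ChatOpenAI(temperature=0.1)
tools = make_tools(
    llm,
    api_keys={
        "OPENAI_API_KEY": "<your-openai-key>",  # placeholder
        "SEMANTIC_SCHOLAR_API_KEY": "<your-semantic-scholar-key>",  # placeholder
    },
)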
4 changes: 2 additions & 2 deletions chemcrow/frontend/streamlit_callback_handler.py
@@ -79,9 +79,9 @@ def on_tool_start(
         )
 
         # Display note of potential long time
-        if serialized["name"] == "ReactionRetrosynthesis":
+        if serialized["name"] == "ReactionRetrosynthesis" or serialized["name"] == "LiteratureSearch":
             self._container.markdown(
-                f"‼️ Note: This tool can take up to 5 minutes to complete execution ‼️",
+                f"‼️ Note: This tool can take some time to complete execution ‼️",
                 unsafe_allow_html=True,
             )
 
39 changes: 28 additions & 11 deletions chemcrow/tools/search.py
@@ -14,14 +14,18 @@
 from chemcrow.utils import is_multiple_smiles, split_smiles
 
 
-def paper_scraper(search: str, pdir: str = "query") -> dict:
+def paper_scraper(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
     try:
-        return paperscraper.search_papers(search, pdir=pdir)
+        return paperscraper.search_papers(
+            search,
+            pdir=pdir,
+            semantic_scholar_api_key=semantic_scholar_api_key,
+        )
     except KeyError:
         return {}
 
 
-def paper_search(llm, query):
+def paper_search(llm, query, semantic_scholar_api_key=None):
     prompt = langchain.prompts.PromptTemplate(
         input_variables=["question"],
         template="""
@@ -37,14 +41,14 @@ def paper_search(llm, query):
         os.mkdir("query/")
     search = query_chain.run(query)
     print("\nSearch:", search)
-    papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}")
+    papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}", semantic_scholar_api_key=semantic_scholar_api_key)
     return papers
 
 
-def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None):
+def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
     """Useful to answer questions that require
     technical knowledge. Ask a specific question."""
-    papers = paper_search(llm, query)
+    papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
     if len(papers) == 0:
         return "Not enough papers found"
     docs = paperqa.Docs(
@@ -59,7 +63,11 @@ def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None):
         except (ValueError, FileNotFoundError, PdfReadError):
             not_loaded += 1
 
-    print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
+    if not_loaded > 0:
+        print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}.")
+    else:
+        print(f"\nFound {len(papers.items())} papers and loaded all of them.")
+
     answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
     return answer
 
@@ -71,15 +79,24 @@ class Scholar2ResultLLM(BaseTool):
         "knowledge. Ask a specific question."
     )
     llm: BaseLanguageModel = None
-    api_key: str = None
+    openai_api_key: str = None
+    semantic_scholar_api_key: str = None
+
 
-    def __init__(self, llm, api_key):
+    def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
         super().__init__()
         self.llm = llm
-        self.api_key = api_key
+        # api keys
+        self.openai_api_key = openai_api_key
+        self.semantic_scholar_api_key = semantic_scholar_api_key
 
     def _run(self, query) -> str:
-        return scholar2result_llm(self.llm, query, openai_api_key=self.api_key)
+        return scholar2result_llm(
+            self.llm,
+            query,
+            openai_api_key=self.openai_api_key,
+            semantic_scholar_api_key=self.semantic_scholar_api_key
+        )
 
     async def _arun(self, query) -> str:
         """Use the tool asynchronously."""
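The reworked tool can also be constructed on its own with the two keyword arguments introduced above. A minimal sketch, assuming the keys are exported in the environment and ChatOpenAI as the model wrapper (neither is prescribed by the commit):

import os

from langchain.chat_models import ChatOpenAI  # assumed model wrapper; any BaseLanguageModel works

from chemcrow.tools.search import Scholar2ResultLLM

llm = ChatOpenAI(temperature=0.1)

# Passing the Semantic Scholar key lets paperscraper authenticate its requests,
# which is what keeps the search under the rate limit mentioned in the commit message.
lit_search = Scholar2ResultLLM(
    llm=llm,
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    semantic_scholar_api_key=os.getenv("SEMANTIC_SCHOLAR_API_KEY"),
)

print(lit_search.run("What catalysts are commonly used in Suzuki couplings?"))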
