Commit
[WIP] pinned preprint versions of langchain and pqa (#12)
* pinned preprint versions of langchain, openai, and pqa, not sure if it'll work

* added async where needed

* enabled pqa test
SamCox822 authored Nov 8, 2023
1 parent 3b7043b commit 0712ef5
Showing 5 changed files with 93 additions and 87 deletions.
3 changes: 3 additions & 0 deletions chemcrow/tools/rxn4chem.py
@@ -53,3 +53,6 @@ def _run(self, reactants: str) -> str:
         product = res_dict["productMolecule"]["smiles"]

         return product
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")
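Each tool touched by this commit gets the same stub: langchain's BaseTool exposes a synchronous _run and an asynchronous _arun, and a tool that only supports synchronous execution conventionally raises NotImplementedError from _arun. A minimal sketch of that pattern with a hypothetical EchoTool (not part of this diff), assuming the langchain 0.0.x BaseTool interface pinned later in setup.py:

from langchain.tools import BaseTool

class EchoTool(BaseTool):
    name = "Echo"
    description = "Returns the input text unchanged."

    def _run(self, query: str) -> str:
        # Synchronous path used by the agent executor.
        return query

    async def _arun(self, query: str) -> str:
        # Async execution is not supported; mirror the stubs added in this commit.
        raise NotImplementedError("Async not implemented.")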
6 changes: 6 additions & 0 deletions chemcrow/tools/safety.py
@@ -160,6 +160,9 @@ def _run(self, cas: str) -> str:

         data = self.mol_safety.get_safety_summary(cas)
         return self.llm_chain.run(" ".join(data))
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")



@@ -185,4 +188,7 @@ def _run(self, cas_number):
             return "Molecule is explosive"
         else:
             return "Molecule is not known to be explosive."
+
+    async def _arun(self, cas_number):
+        raise NotImplementedError("Async not implemented.")

150 changes: 73 additions & 77 deletions chemcrow/tools/search.py
@@ -1,95 +1,91 @@
 import os
 import re
-from typing import Optional

-import langchain
 import paperqa
+import langchain
 import paperscraper
 from langchain import SerpAPIWrapper
-from langchain.base_language import BaseLanguageModel
-from langchain.chains import LLMChain
-from langchain.tools import BaseTool
-from pydantic import validator
 from pypdf.errors import PdfReadError
+from langchain.tools import BaseTool
+from langchain.base_language import BaseLanguageModel


-class LitSearch(BaseTool):
-    name = "LiteratureSearch"
-    description = (
-        "Input a specific question, returns an answer from literature search. "
-        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+def paper_search(search, pdir="query"):
+    try:
+        return paperscraper.search_papers(search, pdir=pdir)
+    except KeyError:
+        return {}
+
+def partial(func, *args, **kwargs):
+    """
+    This function is a workaround for the partial function error in new langchain versions.
+    This can be removed if langchain adds support for partial functions.
+    """
+    def wrapped(*args_wrapped, **kwargs_wrapped):
+        final_args = args + args_wrapped
+        final_kwargs = {**kwargs, **kwargs_wrapped}
+        return func(*final_args, **final_kwargs)
+    return wrapped
+
+def scholar2result_llm(llm, query, search=None):
+    """Useful to answer questions that require technical knowledge. Ask a specific question."""
+
+    prompt = langchain.prompts.PromptTemplate(
+        input_variables=["question"],
+        template="I would like to find scholarly papers to answer this question: {question}. "
+        'A search query that would bring up papers that can answer this question would be: "',
     )
-    llm: BaseLanguageModel
-    query_chain: Optional[LLMChain] = None
-    pdir: str = "query"
-    searches: int = 2
-    verbose: bool = False
-    docs: Optional[paperqa.Docs] = None
+    query_chain = langchain.chains.LLMChain(llm=llm, prompt=prompt)

-    @validator("query_chain", always=True)
-    def init_query_chain(cls, v, values):
-        if v is None:
-            search_prompt = langchain.prompts.PromptTemplate(
-                input_variables=["question", "count"],
-                template="We want to answer the following question: {question} \n"
-                "Provide {count} keyword searches (one search per line) "
-                "that will find papers to help answer the question. "
-                "Do not use boolean operators. "
-                "Make some searches broad and some narrow. "
-                "Do not use boolean operators or quotes.\n\n"
-                "1. ",
-            )
-            v = LLMChain(llm=values["llm"], prompt=search_prompt)
-        return v
+    if not os.path.isdir("./query"):
+        os.mkdir("query/")

-    @validator("pdir", always=True)
-    def init_pdir(cls, v):
-        if not os.path.isdir(v):
-            os.mkdir(v)
-        return v
+    if search is None:
+        search = query_chain.run(query)
+    print("\nSearch:", search)
+    papers = paper_search(search, pdir=f"query/{re.sub(' ', '', search)}")

-    def paper_search(self, search):
+    if len(papers) == 0:
+        return "Not enough papers found"
+    docs = paperqa.Docs(llm=llm)
+    not_loaded = 0
+    for path, data in papers.items():
         try:
-            return paperscraper.search_papers(
-                search, pdir=self.pdir, batch_size=6, limit=4, verbose=False
-            )
-        except KeyError:
-            return {}
+            docs.add(path, data["citation"])
+        except (ValueError, FileNotFoundError, PdfReadError) as e:
+            not_loaded += 1

-    def _run(self, query: str) -> str:
-        if self.verbose:
-            print("\n\nChoosing search terms\n1. ", end="")
-        searches = self.query_chain.run(question=query, count=self.searches)
-        print("")
-        queries = [s for s in searches.split("\n") if len(s) > 3]
-        # remove 2., 3. from queries
-        queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
-        # remove quotes
-        queries = [re.sub(r"\"", "", q) for q in queries]
-        papers = {}
-        for q in queries:
-            papers.update(self.paper_search(q))
-            if self.verbose:
-                print(f"retrieved {len(papers)} papers total")
+    print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
+    return docs.query(query, length_prompt="about 100 words").answer

-        if len(papers) == 0:
-            return "Not enough papers found"
-        if self.docs is None:
-            self.docs = paperqa.Docs(
-                llm=self.llm, summary_llm="gpt-3.5-turbo", memory=True
-            )
-        not_loaded = 0
-        for path, data in papers.items():
-            try:
-                self.docs.add(path, citation=data["citation"], docname=data["key"])
-            except (ValueError, PdfReadError):
-                not_loaded += 1
-
-        if not_loaded:
-            print(f"\nFound {len(papers.items())} papers, couldn't load {not_loaded}")
-        return self.docs.query(query, length_prompt="about 100 words").answer
+def web_search(keywords, search_engine="google"):
+    try:
+        return SerpAPIWrapper(
+            serpapi_api_key=os.getenv("SERP_API_KEY"), search_engine=search_engine
+        ).run(keywords)
+    except:
+        return "No results, try another search"

+
+class LitSearch(BaseTool):
+    name = "LiteratureSearch"
+    description = (
+        "Input a specific question, returns an answer from literature search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    llm: BaseLanguageModel
+    def _run(self, query: str) -> str:
+        return scholar2result_llm(self.llm, query)
     async def _arun(self, query: str) -> str:
-        """Use the tool asynchronously."""
-        raise NotImplementedError()
-
+        raise NotImplementedError("Async not implemented")
+
+class WebSearch(BaseTool):
+    name = "WebSearch"
+    description = (
+        "Input a specific question, returns an answer from web search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    def _run(self, query: str) -> str:
+        return web_search(query)
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("Async not implemented")
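The rewritten search.py drops the validator-based LitSearch in favor of plain functions (paper_search, scholar2result_llm, web_search) wrapped by thin BaseTool classes, which makes the tools easy to exercise directly. A rough usage sketch, assuming ChatOpenAI as the LLM (not shown in this diff) and the pinned langchain/paper-qa versions from setup.py; OPENAI_API_KEY and SERP_API_KEY must be set in the environment:

from langchain.chat_models import ChatOpenAI
from chemcrow.tools.search import LitSearch, WebSearch, partial, web_search

llm = ChatOpenAI(temperature=0.1)

# Literature answer via paperscraper + paper-qa (pinned to 1.1.1 in this commit).
lit = LitSearch(llm=llm)
print(lit.run("Which reagents selectively reduce ketones in the presence of esters?"))

# Web answer via SerpAPI; requires SERP_API_KEY.
web = WebSearch()
print(web.run("common solvents used for Suzuki coupling"))

# The partial() helper pre-binds keyword arguments, e.g. to fix the search engine.
bing_search = partial(web_search, search_engine="bing")
print(bing_search("melting point of aspirin"))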
5 changes: 3 additions & 2 deletions setup.py
@@ -20,10 +20,11 @@
         "ipython",
         "rdkit",
         "synspace",
+        "openai==0.27.8",
         "molbloom",
-        "paper-qa>=3.0.0",
+        "paper-qa==1.1.1",
         "google-search-results",
-        "langchain",
+        "langchain==0.0.234",
         "nest_asyncio",
         "tiktoken",
         "rmrkl",
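The pins matter at runtime because the restored preprint code relies on the older interfaces (paperqa.Docs(llm=llm).add(path, citation), SerpAPIWrapper) that later paper-qa and langchain releases changed. A small illustrative check, not part of the repository, that an environment matches the pins:

import importlib.metadata as md

# Hypothetical sanity check of the pinned dependencies from setup.py.
for pkg, expected in {"openai": "0.27.8", "paper-qa": "1.1.1", "langchain": "0.0.234"}.items():
    installed = md.version(pkg)
    status = "ok" if installed == expected else f"expected {expected}"
    print(f"{pkg}: {installed} ({status})")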
16 changes: 8 additions & 8 deletions tests/test_search.py
@@ -15,12 +15,12 @@ def questions():
     return qs


-#def test_litsearch(questions):
-#    llm = ChatOpenAI()
-#    searchtool = LitSearch(llm=llm)
-#
-#    for q in questions:
-#        ans = searchtool(q)
-#        assert isinstance(ans, str)
-#        assert len(ans) > 0
+def test_litsearch(questions):
+    llm = ChatOpenAI()
+    searchtool = LitSearch(llm=llm)
+
+    for q in questions:
+        ans = searchtool(q)
+        assert isinstance(ans, str)
+        assert len(ans) > 0
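Un-commenting the test means it now runs in the normal pytest suite; it exercises a live LLM through LitSearch, so it needs OPENAI_API_KEY and network access for paperscraper. One way to run just this test from Python, assuming pytest is the project's test runner:

import pytest

# Select only the re-enabled literature-search test; it calls external services,
# so expect it to be slow and to require valid API keys.
raise SystemExit(pytest.main(["-q", "tests/test_search.py::test_litsearch"]))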
