Commit: [WIP] pinned preprint versions of langchain and pqa (#12)

* pinned preprint versions of langchain, open and pqa, not sure if it'll work
* added async where needed
* enabled pqa test
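The diff below only covers the search-tools module; the dependency pins themselves land in the other changed files, which are not shown in this excerpt. For orientation, pinning a dependency to a pre-release GitHub version typically uses pip's direct-reference syntax, as in the hypothetical setup.py fragment below; the org names and commit SHAs are placeholders, not the refs used by this commit.

```python
# Hypothetical setup.py fragment, for illustration only -- the real pinned
# refs for langchain and paper-qa live in files not shown in this excerpt.
from setuptools import find_packages, setup

setup(
    name="example-agent",          # placeholder project name
    packages=find_packages(),
    install_requires=[
        # PEP 508 direct references pin a dependency to an exact git commit
        "langchain @ git+https://github.com/<org>/langchain@<commit-sha>",
        "paper-qa @ git+https://github.com/<org>/paper-qa@<commit-sha>",
        "paperscraper",
        "google-search-results",   # client package used by langchain's SerpAPIWrapper
    ],
)
```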
Showing 5 changed files with 93 additions and 87 deletions.
```diff
@@ -1,95 +1,91 @@
 import os
 import re
-from typing import Optional
 
+import langchain
 import paperqa
-import langchain
 import paperscraper
+from langchain import SerpAPIWrapper
+from langchain.base_language import BaseLanguageModel
 from langchain.chains import LLMChain
+from langchain.tools import BaseTool
-from pydantic import validator
 from pypdf.errors import PdfReadError
-from langchain.tools import BaseTool
-from langchain.base_language import BaseLanguageModel
 
 
-class LitSearch(BaseTool):
-    name = "LiteratureSearch"
-    description = (
-        "Input a specific question, returns an answer from literature search. "
-        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+def paper_search(search, pdir="query"):
+    try:
+        return paperscraper.search_papers(search, pdir=pdir)
+    except KeyError:
+        return {}
+
+def partial(func, *args, **kwargs):
+    """
+    This function is a workaround for the partial function error in new langchain versions.
+    This can be removed if langchain adds support for partial functions.
+    """
+    def wrapped(*args_wrapped, **kwargs_wrapped):
+        final_args = args + args_wrapped
+        final_kwargs = {**kwargs, **kwargs_wrapped}
+        return func(*final_args, **final_kwargs)
+    return wrapped
+
+def scholar2result_llm(llm, query, search=None):
+    """Useful to answer questions that require technical knowledge. Ask a specific question."""
+
+    prompt = langchain.prompts.PromptTemplate(
+        input_variables=["question"],
+        template="I would like to find scholarly papers to answer this question: {question}. "
+        'A search query that would bring up papers that can answer this question would be: "',
     )
-    llm: BaseLanguageModel
-    query_chain: Optional[LLMChain] = None
-    pdir: str = "query"
-    searches: int = 2
-    verbose: bool = False
-    docs: Optional[paperqa.Docs] = None
+    query_chain = langchain.chains.LLMChain(llm=llm, prompt=prompt)
 
-    @validator("query_chain", always=True)
-    def init_query_chain(cls, v, values):
-        if v is None:
-            search_prompt = langchain.prompts.PromptTemplate(
-                input_variables=["question", "count"],
-                template="We want to answer the following question: {question} \n"
-                "Provide {count} keyword searches (one search per line) "
-                "that will find papers to help answer the question. "
-                "Do not use boolean operators. "
-                "Make some searches broad and some narrow. "
-                "Do not use boolean operators or quotes.\n\n"
-                "1. ",
-            )
-            v = LLMChain(llm=values["llm"], prompt=search_prompt)
-        return v
+    if not os.path.isdir("./query"):
+        os.mkdir("query/")
 
-    @validator("pdir", always=True)
-    def init_pdir(cls, v):
-        if not os.path.isdir(v):
-            os.mkdir(v)
-        return v
+    if search is None:
+        search = query_chain.run(query)
+    print("\nSearch:", search)
+    papers = paper_search(search, pdir=f"query/{re.sub(' ', '', search)}")
 
-    def paper_search(self, search):
+    if len(papers) == 0:
+        return "Not enough papers found"
+    docs = paperqa.Docs(llm=llm)
+    not_loaded = 0
+    for path, data in papers.items():
         try:
-            return paperscraper.search_papers(
-                search, pdir=self.pdir, batch_size=6, limit=4, verbose=False
-            )
-        except KeyError:
-            return {}
+            docs.add(path, data["citation"])
+        except (ValueError, FileNotFoundError, PdfReadError) as e:
+            not_loaded += 1
 
-    def _run(self, query: str) -> str:
-        if self.verbose:
-            print("\n\nChoosing search terms\n1. ", end="")
-        searches = self.query_chain.run(question=query, count=self.searches)
-        print("")
-        queries = [s for s in searches.split("\n") if len(s) > 3]
-        # remove 2., 3. from queries
-        queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
-        # remove quotes
-        queries = [re.sub(r"\"", "", q) for q in queries]
-        papers = {}
-        for q in queries:
-            papers.update(self.paper_search(q))
-            if self.verbose:
-                print(f"retrieved {len(papers)} papers total")
+    print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
+    return docs.query(query, length_prompt="about 100 words").answer
 
-        if len(papers) == 0:
-            return "Not enough papers found"
-        if self.docs is None:
-            self.docs = paperqa.Docs(
-                llm=self.llm, summary_llm="gpt-3.5-turbo", memory=True
-            )
-        not_loaded = 0
-        for path, data in papers.items():
-            try:
-                self.docs.add(path, citation=data["citation"], docname=data["key"])
-            except (ValueError, PdfReadError):
-                not_loaded += 1
 
-        if not_loaded:
-            print(f"\nFound {len(papers.items())} papers, couldn't load {not_loaded}")
-        return self.docs.query(query, length_prompt="about 100 words").answer
+def web_search(keywords, search_engine="google"):
+    try:
+        return SerpAPIWrapper(
+            serpapi_api_key=os.getenv("SERP_API_KEY"), search_engine=search_engine
+        ).run(keywords)
+    except:
+        return "No results, try another search"
 
 
+class LitSearch(BaseTool):
+    name = "LiteratureSearch"
+    description = (
+        "Input a specific question, returns an answer from literature search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    llm: BaseLanguageModel
+    def _run(self, query: str) -> str:
+        return scholar2result_llm(self.llm, query)
     async def _arun(self, query: str) -> str:
         """Use the tool asynchronously."""
-        raise NotImplementedError()
+        raise NotImplementedError("Async not implemented")
+
+
+class WebSearch(BaseTool):
+    name = "WebSearch"
+    description = (
+        "Input a specific question, returns an answer from web search. "
+        "Do not mention any specific molecule names, but use more general features to formulate your questions."
+    )
+    def _run(self, query: str) -> str:
+        return web_search(query)
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("Async not implemented")
```