Skip to content

Commit

Permalink
update literature search
Browse files Browse the repository at this point in the history
  • Loading branch information
qcampbel committed Oct 1, 2024
1 parent f4cecf4 commit f72b46d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 79 deletions.
3 changes: 2 additions & 1 deletion mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __init__(
uploaded_files=[], # user input files to add to path registry
run_id="",
use_memory=False,
paper_dir="ckpt/paper_collection", # papers for pqa, relative path within repo
):
self.llm = _make_llm(model, temp, streaming)
if tools_model is None:
tools_model = model
self.tools_llm = _make_llm(tools_model, temp, streaming)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id
Expand Down
102 changes: 30 additions & 72 deletions mdagent/tools/base_tools/util_tools/search_tools.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,44 @@
import logging
import os
import re
from typing import Optional

import langchain
import nest_asyncio
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain_core.output_parsers import StrOutputParser
from pypdf.errors import PdfReadError

from mdagent.utils import PathRegistry


def configure_logging(path):
    """Send ERROR-level log records to ``<path>/scraping_errors.log``.

    paperscraper can be VERY noisy at runtime, so its errors are collected
    in a file instead of being lost or flooding the console.

    A dedicated ``FileHandler`` is attached to the root logger rather than
    calling ``logging.basicConfig``: ``basicConfig`` silently does nothing
    when the root logger already has handlers, which meant the log file was
    never attached if logging had been configured earlier. Calling this
    function repeatedly with the same ``path`` attaches the handler once.

    Args:
        path: Directory in which ``scraping_errors.log`` is written. It is
            created if it does not exist.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    # FileHandler stores an absolute path in ``baseFilename``; normalise the
    # same way so the duplicate check below compares like with like.
    log_file = os.path.abspath(os.path.join(path, "scraping_errors.log"))
    root_logger = logging.getLogger()
    for handler in root_logger.handlers:
        if getattr(handler, "baseFilename", None) == log_file:
            return  # already logging to this file

    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.ERROR)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s:%(levelname)s:%(message)s")
    )
    root_logger.addHandler(file_handler)


def paper_scraper(search: str, pdir: str = "query") -> dict:
    """Run a paperscraper search and return its result mapping.

    Args:
        search: Search string forwarded to ``paperscraper.search_papers``.
        pdir: Directory passed through as ``pdir`` to the search call.

    Returns:
        Whatever ``paperscraper.search_papers`` returns, or an empty dict
        when the search raises ``KeyError``.
    """
    try:
        found = paperscraper.search_papers(search, pdir=pdir)
    except KeyError:
        found = {}
    return found


def paper_search(llm, query, path_registry):
    """Ask the LLM for a short search query and scrape papers matching it.

    Args:
        llm: Language model runnable used to turn ``query`` into a search
            string.
        query: The user's question.
        path_registry: Object exposing ``ckpt_files``; downloads go under
            ``<ckpt_files>/query``.

    Returns:
        The result of ``paper_scraper`` for the generated search string.
    """
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
I would like to find scholarly papers to answer
this question: {question}. Your response must be at
most 10 words long.
'A search query that would bring up papers that can answer
this question would be: '""",
    )

    path = f"{path_registry.ckpt_files}/query"
    query_chain = prompt | llm | StrOutputParser()
    # makedirs also creates missing parents and tolerates an existing
    # directory, avoiding the check-then-create race of isdir + mkdir.
    os.makedirs(path, exist_ok=True)
    configure_logging(path)
    search = query_chain.invoke(query)
    print("\nSearch:", search)
    # The search string is raw LLM output and may contain quotes, slashes or
    # dots; keep only word characters and hyphens so it is a safe single
    # directory name that cannot escape ``path``.
    subdir = re.sub(r"[^\w-]", "", search) or "search"
    papers = paper_scraper(search, pdir=f"{path}/{subdir}")
    return papers


def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
if llm.model_name.startswith("gpt"):
docs = paperqa.Docs(llm=llm.model_name)
def scholar2result_llm(llm, query, path_registry):
paper_directory = path_registry.ckpt_papers
if paper_directory is None:
raise ValueError("The 'paper_dir' is None and wasn't set from the start.")
print("Paper Directory", paper_directory)
llm_name = llm.model_name
if llm_name.startswith("gpt") or llm_name.startswith("claude"):
settings = paperqa.Settings(
llm=llm_name,
summary_llm=llm_name,
temperature=llm.temperature,
paper_directory=paper_directory,
)
else:
docs = paperqa.Docs() # uses default gpt model in paperqa

papers = paper_search(llm, query, path_registry)
if len(papers) == 0:
return "Failed. Not enough papers found"
not_loaded = 0
for path, data in papers.items():
try:
docs.add(path, data["citation"])
except (ValueError, FileNotFoundError, PdfReadError):
not_loaded += 1

print(
f"\nFound {len(papers)} papers"
+ (f" but couldn't load {not_loaded}" if not_loaded > 0 else "")
)
answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
return "Succeeded. " + answer
settings = paperqa.Settings(
temperature=llm.temperature, # uses default gpt model in paperqa
paper_directory=paper_directory,
)
response = paperqa.ask(query, settings=settings)
answer = response.answer.formatted_answer
if "I cannot answer." in answer:
answer += f" Check to ensure there's papers in {paper_directory}"
print(answer)
return answer


class Scholar2ResultLLM(BaseTool):
name = "LiteratureSearch"
description = (
"Useful to answer questions that require technical "
"knowledge. Ask a specific question."
"Useful to answer questions that may be found in literature. "
"Ask a specific question as the input."
)
llm: BaseLanguageModel = None
path_registry: Optional[PathRegistry]
Expand All @@ -96,7 +50,11 @@ def __init__(self, llm, path_registry):

def _run(self, query) -> str:
nest_asyncio.apply()
return scholar2result_llm(self.llm, query, self.path_registry)
try:
return scholar2result_llm(self.llm, query, self.path_registry)
except Exception as e:
print(e)
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
Expand Down
15 changes: 10 additions & 5 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,24 @@ class PathRegistry:

@classmethod
# set ckpt_dir to None by default
def get_instance(cls, ckpt_dir=None):
def get_instance(cls, ckpt_dir=None, paper_dir=None):
# todo: use same ckpt if run_id is given
if not cls.instance or ckpt_dir is not None:
cls.instance = cls(ckpt_dir)
cls.instance = cls(ckpt_dir, paper_dir)
return cls.instance

def __init__(self, ckpt_dir: str = "ckpt"):
self._set_ckpt(ckpt_dir)
def __init__(
self, ckpt_dir: str = "ckpt", paper_dir: str = "ckpt/paper_collection"
):
self._set_ckpt(ckpt_dir, paper_dir)
self._make_all_dirs()
self._init_path_registry()

def _set_ckpt(self, ckpt: str):
def _set_ckpt(self, ckpt: str, paper_dir: str):
self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt)
if paper_dir is not None:
paper_dir = os.path.join(self.set_ckpt.find_root_dir(), paper_dir)
self.ckpt_papers = paper_dir

def _make_all_dirs(self):
self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"matplotlib",
"nbformat",
"openai",
"paper-qa==4.0.0rc8 ",
"paper-qa==5.0.6",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
"pandas",
"pydantic>=2.6",
Expand Down

0 comments on commit f72b46d

Please sign in to comment.