Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi03071991 committed Feb 12, 2024
1 parent 620002f commit 335a7ac
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
12 changes: 8 additions & 4 deletions llama_hub/llama_packs/corrective_rag/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Corrective-RAG Pack
# Corrective Retrieval Augmented Generation Pack

Create a query engine using completely local and private models -- `HuggingFaceH4/zephyr-7b-beta` for the LLM and `BAAI/bge-base-en-v1.5` for embeddings.
This LlamaPack implements the Corrective Retrieval Augmented Generation (CRAG) [paper](https://arxiv.org/pdf/2401.15884.pdf)

Corrective Retrieval Augmented Generation (CRAG) is a method designed to enhance the robustness of language model generation by evaluating and augmenting the relevance of retrieved documents through a lightweight evaluator and large-scale web searches, ensuring more accurate and reliable information is used in generation.

This LlamaPack uses the [Tavily AI](https://app.tavily.com/home) API for web searches, so we recommend you obtain an API key before proceeding further.

## CLI Usage

Expand All @@ -25,12 +29,12 @@ CorrectiveRAGPack = download_llama_pack(
)

# You can use any llama-hub loader to get documents!
corrective_rag_pack = CorrectiveRAGPack(documents)
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_api_key)
```

From here, you can use the pack, or inspect and modify the pack in `./corrective_rag_pack`.

The `run()` function contains around logic behind Corrective RAG - [CRAG](https://arxiv.org/pdf/2401.15884.pdf) PAPER.
The `run()` function contains the core logic behind Corrective Retrieval Augmented Generation - [CRAG](https://arxiv.org/pdf/2401.15884.pdf) paper.

```python
response = corrective_rag_pack.run("What did the author do growing up?", similarity_top_k=2)
Expand Down
58 changes: 35 additions & 23 deletions llama_hub/llama_packs/corrective_rag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,66 @@
from typing import Any, Dict, List
from prompts import DEFAULT_TRANSFORM_QUERY_TEMPLATE, DEFAULT_RELEVANCY_PROMPT_TEMPLATE

from llama_index import ServiceContext, VectorStoreIndex, SummaryIndex
from llama_index import VectorStoreIndex, SummaryIndex
from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.llms import OpenAI
from llama_index.schema import Document, NodeWithScore
from llama_index.query_pipeline.query import QueryPipeline
from llama_hub.tools.tavily_research.base import TavilyToolSpec


class CorrectiveRAGPack(BaseLlamaPack):
def __init__(self,
documents: List[Document],
tavily_ai_apikey: str) -> None:
def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
"""Init params."""

llm = OpenAI(model='gpt-4')
self.relevancy_pipeline = QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE,
llm])
self.transform_query_pipeline = QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE,
llm])

llm = OpenAI(model="gpt-4")
self.relevancy_pipeline = QueryPipeline(
chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
)
self.transform_query_pipeline = QueryPipeline(
chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
)

self.llm = llm
self.index = VectorStoreIndex.from_documents(documents)
self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)

def get_modules(self) -> Dict[str, Any]:
    """Expose the pack's internal components for inspection.

    Returns:
        A dict with two entries: ``"llm"`` (the OpenAI LLM configured in
        ``__init__``) and ``"index"`` (the in-memory vector store index
        built from the supplied documents).
    """
    modules: Dict[str, Any] = {}
    modules["llm"] = self.llm
    modules["index"] = self.index
    return modules

def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
    """Retrieve the nodes most relevant to ``query_str`` from the vector index.

    Any keyword arguments (e.g. ``similarity_top_k``) are forwarded unchanged
    to ``VectorStoreIndex.as_retriever`` when building the retriever.
    """
    return self.index.as_retriever(**kwargs).retrieve(query_str)

def evaluate_relevancy(self, retrieved_nodes: List[Document], query_str: str) -> List[str]:
def evaluate_relevancy(
self, retrieved_nodes: List[Document], query_str: str
) -> List[str]:
"""Evaluate relevancy of retrieved documents with the query"""
relevancy_results = []
for node in retrieved_nodes:
relevancy = self.relevancy_pipeline.run(context_str=node.text, query_str=query_str)
relevancy = self.relevancy_pipeline.run(
context_str=node.text, query_str=query_str
)
relevancy_results.append(relevancy.message.content.lower().strip())
return relevancy_results

def extract_relevant_texts(self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]) -> str:
def extract_relevant_texts(
self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
) -> str:
"""Extract relevant texts from retrieved documents"""
relevant_texts = [retrieved_nodes[i].text for i, result in enumerate(relevancy_results) if result == 'yes']
return '\n'.join(relevant_texts)
relevant_texts = [
retrieved_nodes[i].text
for i, result in enumerate(relevancy_results)
if result == "yes"
]
return "\n".join(relevant_texts)

def search_with_transformed_query(self, query_str: str) -> str:
"""Search the transformed query with Tavily API"""
search_results = self.tavily_tool.search(query_str, max_results=2)
return '\n'.join([result.text for result in search_results])
return "\n".join([result.text for result in search_results])

def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
"""Get result with relevant text"""
Expand All @@ -69,13 +80,15 @@ def run(self, query_str: str, **kwargs: Any) -> Any:

# Extract texts from documents that are deemed relevant based on the evaluation.
relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

# Initialize search_text variable to handle cases where it might not get defined.
search_text = ''
search_text = ""

# If any document is found irrelevant, transform the query string for better search results.
if "no" in relevancy_results:
transformed_query_str = self.transform_query_pipeline.run(query_str=query_str).message.content
transformed_query_str = self.transform_query_pipeline.run(
query_str=query_str
).message.content

# Conduct a search with the transformed query string and collect the results.
search_text = self.search_with_transformed_query(transformed_query_str)
Expand All @@ -85,5 +98,4 @@ def run(self, query_str: str, **kwargs: Any) -> Any:
if search_text:
return self.get_result(relevant_text, search_text, query_str)
else:
return self.get_result(relevant_text, '', query_str)

return self.get_result(relevant_text, "", query_str)
3 changes: 2 additions & 1 deletion llama_hub/llama_packs/corrective_rag/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@
{query_str}
\n ------- \n
Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
Respond with the optimized query only:""")
Respond with the optimized query only:"""
)

0 comments on commit 335a7ac

Please sign in to comment.