Support simple self-rag short form
MarouaneMaatouk committed Jan 28, 2024
1 parent e382a46 commit 801302d
Showing 6 changed files with 995 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llama_hub/llama_packs/library.json
@@ -270,5 +270,10 @@
"id": "llama_packs/vanna",
"author": "jerryjliu",
"keywords": ["vanna", "sql", "ai", "text-to-sql"]
},
"SelfRAGPack": {
"id": "llama_packs/self_rag",
"author": "mmaatouk",
"keywords": ["self-RAG", "llm", "smart-retreiver"]
}
}
63 changes: 63 additions & 0 deletions llama_hub/llama_packs/self_rag/README.md
@@ -0,0 +1,63 @@
# Simple self-RAG short form pack

This LlamaPack implements the short-form variant of the [self-RAG paper by Asai et al.](https://arxiv.org/pdf/2310.11511.pdf).

Self-Reflective Retrieval-Augmented Generation (SELF-RAG) is a framework that aims to enhance the quality and factuality of large language models (LLMs) by combining retrieval and self-reflection mechanisms.

The implementation is adapted from the authors' [reference implementation](https://github.com/AkariAsai/self-rag).
A full notebook guide can be found [here](https://github.com/run-llama/llama-hub/blob/main/llama_hub/llama_packs/self_rag/self_rag.ipynb).


## CLI Usage

You can download LlamaPacks directly using `llamaindex-cli`, which is installed with the `llama-index` Python package:

```bash
llamaindex-cli download-llamapack SelfRAGPack --download-dir ./self_rag_pack
```

You can then inspect the files at `./self_rag_pack` and use them as a template for your own project!

## Code Usage

We will show you how to import the query engine from these files.
The implementation uses llama-cpp. To download the relevant model, run the following (be sure to replace `<DIR_PATH>` with a local directory):
```bash
pip3 install -q huggingface-hub
huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir "<DIR_PATH>" --local-dir-use-symlinks False
```

```python
from llama_index.llama_pack import download_llama_pack

# download and install dependencies
SelfRAGPack = download_llama_pack(
    "SelfRAGPack", "./self_rag_pack"
)
```

From here, you can use the pack. You can import the relevant modules from the download folder (in the example below we assume it's a relative import or the directory
has been added to your system path).

```python
from self_rag_pack.base import SelfRAGQueryEngine

query_engine = SelfRAGQueryEngine(model_path=model_path, retriever=retriever, verbose=True)

response = query_engine.query("Who won best Director in the 1972 Academy Awards?")
```
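
The example above assumes that `model_path` and `retriever` are already defined. Below is a minimal sketch of how they might be set up, assuming the GGUF weights downloaded earlier and a small vector index built with `llama_index` over local documents (the `./data` directory and `similarity_top_k` value are placeholders):

```python
from llama_index import SimpleDirectoryReader, VectorStoreIndex

# Path to the GGUF weights downloaded with huggingface-cli above (placeholder).
model_path = "<DIR_PATH>/selfrag_llama2_7b.q4_k_m.gguf"

# Build a small vector index over local documents and expose it as a retriever.
documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)
retriever = index.as_retriever(similarity_top_k=5)
```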

You can also use/initialize the pack directly.

```python
from self_rag_pack.base import SelfRAGPack

agent_pack = SelfRAGPack(model_path=model_path, retriever=retriever, verbose=True)
```

The `run()` function is a light wrapper around the query engine's `query()` method.

```python
response = agent_pack.run("Who won best Director in the 1972 Academy Awards?")
```
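
As a quick illustration (a sketch, assuming the `agent_pack` initialized above), the modules bundled in the pack can also be pulled out individually via `get_modules()`:

```python
# Inspect the modules bundled in the pack.
modules = agent_pack.get_modules()
query_engine = modules["query_engine"]  # SelfRAGQueryEngine
llm = modules["llm"]  # llama-cpp model instance
retriever = modules["retriever"]  # the retriever passed at init
```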
Empty file.
280 changes: 280 additions & 0 deletions llama_hub/llama_packs/self_rag/base.py
@@ -0,0 +1,280 @@
from typing import Any, Dict, List, Optional, Tuple
import numpy as np

from llama_cpp import Llama
from llama_index import Response
from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.bridge.pydantic import Field
from llama_index.query_engine import CustomQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.utils import print_text

_RELEVANCE_TOKENS = ["[Irrelevant]", "[Relevant]"]

_RETRIEVAL_TOKENS = ["[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]"]
_UTILITY_TOKENS = [
"[Utility:1]",
"[Utility:2]",
"[Utility:3]",
"[Utility:4]",
"[Utility:5]",
]
_GROUND_TOKENS = [
"[Fully supported]",
"[Partially supported]",
"[No support / Contradictory]",
]
_CTRL_TOKENS = [
"[Fully supported]",
"[Partially supported]",
"[No support / Contradictory]",
"[No Retrieval]",
"[Retrieval]",
"[Irrelevant]",
"[Relevant]",
"[Continue to Use Evidence]",
"<paragraph>",
"</paragraph>",
"[Utility:1]",
"[Utility:2]",
"[Utility:3]",
"[Utility:4]",
"[Utility:5]",
]

_MODEL_KWARGS = {"logits_all": True, "n_ctx": 2048, "n_gpu_layers": -1}
_GENERATE_KWARGS = {
"temperature": 0.0,
"top_p": 1.0,
"max_tokens": 50,
"logprobs": 32016,
}


def _format_prompt(input: str, paragraph: Optional[str] = None) -> str:
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt


def _postprocess_answer(answer: str) -> str:
for token in _CTRL_TOKENS:
answer = answer.replace(token, "")

if "</s>" in answer:
answer = answer.replace("</s>", "")
if "\n" in answer:
answer = answer.replace("\n", "")

if "<|endoftext|>" in answer:
answer = answer.replace("<|endoftext|>", "")

return answer


def _relevance_score(pred_log_probs: Dict[str, float]) -> float:
"""Compute relevance score
Args:
pred_log_probs (Dict[str, float]): log probabilities of tokens
Returns:
float: relevance score
"""
rel_prob = np.exp(float(pred_log_probs["[Relevant]"]))
irel_prob = np.exp(float(pred_log_probs["[Irrelevant]"]))
return rel_prob / (rel_prob + irel_prob)


def _is_supported_score(
pred_tokens: List[int], pred_log_probs_dict: List[Dict[str, float]]
) -> float:
"""Compute support score
Args:
pred_tokens (List[int]): List of predicted tokens
        pred_log_probs_dict (List[Dict[str, float]]): log probabilities of tokens for each predicted token
Returns:
float: support score
"""
isSup_score = 0
groundness_token_appear_id = -1
for tok_idx, token in enumerate(pred_tokens):
if token in _GROUND_TOKENS:
groundness_token_appear_id = tok_idx
break
if groundness_token_appear_id > -1:
grd_score_dict = {}
for token in _GROUND_TOKENS:
prob = pred_log_probs_dict[groundness_token_appear_id][token]
grd_score_dict[token] = np.exp(float(prob))
isSup_score = (
grd_score_dict["[Fully supported]"]
+ 0.5 * grd_score_dict["[Partially supported]"]
) / np.sum(list(grd_score_dict.values()))
return isSup_score


def _is_useful_score(
pred_tokens: List[int], pred_log_probs_dict: List[Dict[str, float]]
) -> float:
"""Compute usefulness score
Args:
pred_tokens (List[int]): List of predicted tokens
        pred_log_probs_dict (List[Dict[str, float]]): log probabilities of tokens for each predicted token
    Returns:
        float: usefulness score
"""
isUse_score = 0
utility_token_appear_id = -1
for tok_idx, tok in enumerate(pred_tokens):
if tok in _UTILITY_TOKENS:
utility_token_appear_id = tok_idx
if utility_token_appear_id > -1:
ut_score_dict = {}
for token in _UTILITY_TOKENS:
prob = pred_log_probs_dict[utility_token_appear_id][token]
ut_score_dict[token] = np.exp(float(prob))

ut_sum = np.sum(list(ut_score_dict.values()))
ut_weights = [-1, -0.5, 0, 0.5, 1]
isUse_score = np.sum(
[
ut_weights[i] * (ut_score_dict["[Utility:{}]".format(i + 1)] / ut_sum)
for i in range(len(ut_weights))
]
)
return isUse_score


class SelfRAGQueryEngine(CustomQueryEngine):
"""Simple short form self RAG query engine."""

llm: Llama = Field(default=None, description="llm")
generate_kwargs: Dict = Field(default=None, description="llm generation arguments")
    retriever: BaseRetriever = Field(default=None, description="Retriever")
verbose: bool = Field(default=True, description="Verbose.")

def __init__(
self,
model_path: str,
        retriever: BaseRetriever,
verbose: bool = False,
model_kwargs: Dict = None,
generate_kwargs: Dict = None,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(verbose=verbose, **kwargs)
model_kwargs = model_kwargs or _MODEL_KWARGS
self.generate_kwargs = generate_kwargs or _GENERATE_KWARGS

self.llm = Llama(model_path=model_path, verbose=verbose, **model_kwargs)
        self.retriever = retriever

def _run_critic(
self, paragraphs: List[str]
) -> Tuple[Dict[int, float], Dict[int, str]]:
"""Run Critic component, the llm will generate responses based on the paragraphs and then evaluate them
Args:
paragraphs (List[str]): List of paragraphs to evaluate
Returns:
Tuple[Dict[int, float], Dict[int, str]]: Paragraphs final score and LLM predictions
"""
paragraphs_final_score = {}
llm_response_text = {}

for p_idx, paragraph in enumerate(paragraphs):
pred = self.llm(paragraph, **self.generate_kwargs)
# Cache llm answer
llm_response_text[p_idx] = pred["choices"][0]["text"]

logprobs = pred["choices"][0]["logprobs"]
pred_log_probs = logprobs["top_logprobs"]
# Compute isRel score, on the first predicted token
isRel_score = _relevance_score(pred_log_probs[0])

# Compute isSup score
isSup_score = _is_supported_score(logprobs["tokens"], pred_log_probs)

# Compute isUse score
isUse_score = _is_useful_score(logprobs["tokens"], pred_log_probs)

paragraphs_final_score[p_idx] = (
isRel_score + isSup_score + 0.5 * isUse_score
)
if self.verbose:
print_text(
f"Input: {paragraph}\nPrediction: {llm_response_text[p_idx]}\nScore: {paragraphs_final_score[p_idx]}\n",
color="blue",
)
print_text(
f"{p_idx + 1}/{len(paragraphs)} paragraphs done\n\n", color="blue"
)

return paragraphs_final_score, llm_response_text

def custom_query(self, query_str: str) -> Response:
"""Run self-RAG."""
response = self.llm(prompt=_format_prompt(query_str), **_GENERATE_KWARGS)
answer = response["choices"][0]["text"]
if "[Retrieval]" in answer:
if self.verbose:
print_text("Retreival required\n", color="blue")
            documents = self.retriever.retrieve(query_str)
if self.verbose:
print_text(f"Received: {len(documents)} documents\n", color="blue")
paragraphs = [
_format_prompt(query_str, document.node.text) for document in documents
]

if self.verbose:
print_text("Start evaluation\n", color="blue")
paragraphs_final_score, llm_response_text = self._run_critic(paragraphs)
if self.verbose:
print_text("End evaluation\n", color="blue")

best_paragraph_id = max(
paragraphs_final_score, key=paragraphs_final_score.get
)
answer = llm_response_text[best_paragraph_id]
if self.verbose:
print_text(f"Selected the best answer: {answer}\n", color="blue")

answer = _postprocess_answer(answer)
if self.verbose:
print_text(f"Final answer: {answer}\n", color="green")
return Response(response=str(answer))


class SelfRAGPack(BaseLlamaPack):
"""Simple short form Self-RAG pack."""

def __init__(
self,
model_path: str,
        retriever: BaseRetriever,
verbose: bool = False,
**kwargs: Any,
) -> None:
"""Init params."""

        self.query_engine = SelfRAGQueryEngine(model_path, retriever, verbose)

def get_modules(self) -> Dict[str, Any]:
"""Get modules."""
return {
"query_engine": self.query_engine,
"llm": self.query_engine.llm,
"retreiver": self.query_engine.retreiver,
}

def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run the pipeline."""
return self.query_engine.query(*args, **kwargs)
1 change: 1 addition & 0 deletions llama_hub/llama_packs/self_rag/requirements.txt
@@ -0,0 +1 @@
llama_cpp_python