Showing 31 changed files with 2,473 additions and 18 deletions.

`.gitignore`
@@ -0,0 +1,2 @@
**.pyc
__pycache__/

`README.md`
@@ -1,28 +1,79 @@
# Introduction
Source code for the paper *Attention Satisfies: A Constraint-Satisfaction Lens on Factual Errors of Language Models* by Mert Yuksekgonul, Varun Chandrasekaran, Erik Jones, Suriya Gunasekar, Ranjita Naik, Hamid Palangi, Ece Kamar, and Besmira Nushi. ICLR 2024.

# Getting Started
The repository consists of:
- Scripts for collecting attention-based metrics from open-source models in the Llama-2 family at inference time.
- Scripts for computing these metrics across a given dataset and training a simple linear probe that predicts whether the model will make a factual error.
- A simple tool for visualizing model attention on a given instance.
- The factual knowledge data used in the paper for evaluation.

**Installation**:<br>
The code is written in Python 3.11; use `requirements.txt` to install the required packages:
```bash
conda create -n satprobe python=3.11
conda activate satprobe
pip install -r requirements.txt
```
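
Since the examples below place the model on a CUDA device (`device_map="cuda"`), a quick environment sanity check can help. This is an optional sketch, assuming PyTorch is among the pinned requirements:
```python
import torch

# The attention examples below assume a CUDA-capable GPU is visible.
print(torch.__version__, torch.cuda.is_available())
```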

## Datasets
We provide the datasets used in the paper in the `factual_queries` folder. Loading one is as simple as:
```python
from factual_queries import load_constraint_dataset
items = load_constraint_dataset('basketball_players')
print(items[0])
# {'player_name': 'Michael Jordan', 'label': 1963, 'prompt': 'User: Tell me the year the basketball player Michael Jordan was born in.\nAssistant: The player was born in', ...
```
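
Each loader returns a list of dictionaries. As the loaders in `factual_queries/single_constraint.py` (included in this commit) show, every item carries a `constraint`, a fully formatted `prompt`, and a ground-truth `label`; most single-constraint datasets also attach a `popularity` score. A minimal sketch for inspecting these fields:
```python
from factual_queries import load_constraint_dataset

items = load_constraint_dataset('basketball_players')
for item in items[:3]:
    # 'popularity' is absent in some datasets, so use .get() for safety.
    print(item['constraint'], item['label'], item.get('popularity'))
```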

## Attention Tools
`model_lib/attention_tools.py` contains the code for collecting attention-based metrics from the Llama-2 family of models. It can be used as follows:
```python
import torch
import transformers

from model_lib import HF_Llama2_Wrapper, run_attention_monitor

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                          torch_dtype=torch.bfloat16,
                                                          device_map="cuda")

model_wrapped = HF_Llama2_Wrapper(model, tokenizer, device="cuda")
prompt_info = {"prompt": "The great Michael Jordan was born in",
               "constraints": [" Michael Jordan"]}
data = run_attention_monitor(prompt_info, model_wrapped)
```
This `data` object contains the attention flow information for the given prompt and constraints. You can use it to visualize the attention flow as follows:

```python
from viz_tools import plot_attention_flow

# Collect the attention contribution data for the prompt tokens
flow_matrix = data.all_token_contrib_norms[:, 1:data.num_prompt_tokens].T
# Get the matching token labels
token_labels = data.token_labels[1:data.num_prompt_tokens]
fig = plot_attention_flow(flow_matrix, token_labels)
fig
```
which produces the following visualization:


### Detecting Factual Errors
Our probing experiments have two main steps:
- Collecting attention-based metrics for a given dataset, done with `main_flow_collection.py`.
- Training a simple linear probe on the collected metrics, done with `main_probe.py` (a conceptual sketch of this step follows the example below).

An example of how to use these scripts:
```bash
python main_flow_collection.py --dataset_name basketball_players --model_name meta-llama/Llama-2-7b-hf --output_dir ./outputs
python main_probe.py --dataset_name basketball_players --model_name meta-llama/Llama-2-7b-hf --output_dir ./outputs
```
which saves the resulting figures and probe results in the `./outputs` folder.
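
To make the probing step concrete, here is a minimal sketch of what training a linear probe involves. This is illustrative only, not the repo's `main_probe.py`, and the feature/label file names are hypothetical:
```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hypothetical arrays: per-example attention features and factual-error labels.
X = np.load("outputs/attention_features.npy")
y = np.load("outputs/correctness_labels.npy")

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("Held-out probe accuracy:", probe.score(X_te, y_te))
```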

# Contact
Mert Yuksekgonul ([email protected])
Besmira Nushi ([email protected])

# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [[email protected]](mailto:[email protected]) with any additional questions or comments.

## Trademarks

`factual_queries/__init__.py`
@@ -0,0 +1,29 @@
import numpy as np
from .single_constraint import *
from .multi_constraint import *


def load_constraint_dataset(dataset_name, subsample_count=None):
    # Dispatch to the loader for the requested dataset.
    if dataset_name == "basketball_players":
        items = load_basketball_players()
    elif dataset_name == "football_teams":
        items = load_football_teams()
    elif dataset_name == "songs":
        items = load_songs()
    elif dataset_name == "movies":
        items = load_movies()
    elif "counterfact_" in dataset_name:
        items = load_counterfact_subset(dataset_name)
    elif dataset_name == "nobel":
        items = load_nobel_city()
    elif dataset_name == "words":
        items = load_word_startend()
    elif dataset_name == "books":
        items = load_books()
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if subsample_count is not None:
        # Draw a uniform random subset of items without replacement.
        items = np.random.choice(items, subsample_count, replace=False)

    return items
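
A usage note on `load_constraint_dataset`, assuming the package layout shown in this commit: when `subsample_count` is set, `np.random.choice` returns a NumPy object array rather than a Python list, which is worth normalizing if downstream code expects list semantics.
```python
from factual_queries import load_constraint_dataset

items = load_constraint_dataset("songs", subsample_count=50)
print(type(items))   # numpy.ndarray when subsampled, list otherwise
items = list(items)  # normalize if list semantics are needed
```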

Binary files (the pickled datasets under `factual_queries/data/`) are not shown.

`factual_queries/multi_constraint.py`
@@ -0,0 +1,19 @@
import pickle
import numpy as np

### Multi-constraint datasets

def load_nobel_city():
    filename = "./factual_queries/data/nobel_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items

def load_word_startend():
    filename = "./factual_queries/data/word_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items

def load_books():
    filename = "./factual_queries/data/books_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items
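
Since these three loaders differ only in the pickle path, a compact alternative (a sketch, not part of this commit) is a single table-driven helper:
```python
import pickle

# Hypothetical helper: map dataset names to their pickle paths.
_MULTI_CONSTRAINT_PATHS = {
    "nobel": "./factual_queries/data/nobel_multiconstraint.pkl",
    "words": "./factual_queries/data/word_multiconstraint.pkl",
    "books": "./factual_queries/data/books_multiconstraint.pkl",
}

def load_multi_constraint(name):
    # Look up the path for the requested dataset and unpickle its item list.
    with open(_MULTI_CONSTRAINT_PATHS[name], "rb") as f:
        return pickle.load(f)
```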

`factual_queries/single_constraint.py`
@@ -0,0 +1,60 @@
import pickle
import numpy as np


def load_basketball_players():
    with open("./factual_queries/data/basketball_players.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the year the basketball player {} was born in."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The player was born in"
    for item in items:
        item["constraint"] = item["player_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["birth_year"]
        # 'popularity' is already present in the pickled items.
    return items


def load_football_teams():
    with open("./factual_queries/data/football_teams.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the year the football team {} was founded in."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The team was founded in"
    for item in items:
        item["constraint"] = item["team_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["founding_year"]
    return items


def load_songs():
    with open("./factual_queries/data/songs.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the performer of the song {}"
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The performer is"
    for item in items:
        item["constraint"] = item["song_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["artist_name"]
    return items


def load_movies():
    with open("./factual_queries/data/movies.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the director of the movie {}."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The director is"
    for item in items:
        item["constraint"] = item["movie_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["director_name"]
    return items


def load_counterfact_subset(subset):
    filename = f"./factual_queries/data/{subset}.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items
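
For reference, every single-constraint loader produces the same chat-style prompt shape; a quick illustration (the movie name here is just an example value):
```python
prompt_template = "Tell me the director of the movie {}."
prompt = f"User: {prompt_template.format('Titanic')}\nAssistant: The director is"
print(prompt)
# User: Tell me the director of the movie Titanic.
# Assistant: The director is
```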