Commit c944746

initial commit for release

mertyg committed Apr 2, 2024
1 parent 8018efb commit c944746
Showing 31 changed files with 2,473 additions and 18 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
**.pyc
__pycache__/
87 changes: 69 additions & 18 deletions README.md
@@ -1,28 +1,79 @@
# Introduction
Source code for the paper *Attention Satisfies: A Constraint-Satisfaction Lens on Factual Errors of Language Models* (Mert Yuksekgonul, Varun Chandrasekaran, Erik Jones, Suriya Gunasekar, Ranjita Naik, Hamid Palangi, Ece Kamar, Besmira Nushi; ICLR 2024).

# Getting Started
The repository consists of:
- Scripts for collecting attention-based metrics at inference time from open-source models in the Llama-2 family.
- Scripts for computing these metrics across a given dataset and training a simple linear probe that predicts whether the model will make a factual error.
- A simple tool for visualizing model attention on a given instance.
- The factual knowledge datasets used in the paper for evaluation.

**Installation**:<br>
The code is written in Python 3.11, and you can use `requirements.txt` to install the required packages.
```bash
conda create -n satprobe python=3.11
conda activate satprobe
pip install -r requirements.txt
```
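
As a quick sanity check after installation, you can confirm that the core dependencies the examples below rely on import cleanly (this assumes `requirements.txt` pins `torch` and `transformers`):
```python
# Sanity check: the examples below rely on torch and transformers,
# which are assumed to be installed via requirements.txt.
import torch
import transformers

print(torch.__version__, transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```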

## Datasets
We provide the datasets used in the paper in the `factual_queries` folder. Loading a dataset is as simple as:
```python
from factual_queries import load_constraint_dataset
items = load_constraint_dataset('basketball_players')
print(items[0])
# {'player_name': 'Michael Jordan', 'label': 1963, 'prompt': 'User: Tell me the year the basketball player Michael Jordan was born in.\nAssistant: The player was born in', ...
```

## Attention Tools
`model_lib/attention_tools.py` contains the code for collecting attention-based metrics from the Llama-2 family of models. It can be used as follows:
```python
import torch
import transformers

from model_lib import HF_Llama2_Wrapper, run_attention_monitor

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                          torch_dtype=torch.bfloat16,
                                                          device_map="cuda")

model_wrapped = HF_Llama2_Wrapper(model, tokenizer, device="cuda")
prompt_info = {"prompt": "The great Michael Jordan was born in",
               "constraints": [" Michael Jordan"]}
data = run_attention_monitor(prompt_info, model_wrapped)
```
The returned `data` object contains the attention-flow information for the given prompt and constraints. You can use it to visualize the attention flow as follows:

```python
from viz_tools import plot_attention_flow
# Collect the attention contribution data
flow_matrix = data.all_token_contrib_norms[:, 1:data.num_prompt_tokens].T
# Get token labels
token_labels = data.token_labels[1:data.num_prompt_tokens]
fig = plot_attention_flow(flow_matrix, token_labels)
fig
```
which would produce the following visualization:
![attention_flow](./assets/sample_jordan.png)

### Detecting Factual Errors
Our probing experiments have two main steps:
- Collect attention-based metrics for a given dataset, using `main_flow_collection.py`.
- Train a simple linear probe on the collected metrics, using `main_probe.py`.

An example of how to use these scripts is as follows:
```bash
python main_flow_collection.py --dataset_name basketball_players --model_name meta-llama/Llama-2-7b-hf --output_dir ./outputs
python main_probe.py --dataset_name basketball_players --model_name meta-llama/Llama-2-7b-hf --output_dir ./outputs
```
which would save the resulting figures and probe results in the `./outputs` folder.
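
For intuition, here is a minimal sketch of the kind of linear probe `main_probe.py` trains; the feature matrix `X` and labels `y` below are hypothetical stand-ins for the attention metrics and error labels collected by `main_flow_collection.py`:
```python
# Hypothetical sketch of a linear probe over attention features.
# X and y stand in for the metrics and error labels collected by
# main_flow_collection.py; here they are random placeholders.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 32))    # per-example attention metrics
y = rng.integers(0, 2, size=500)  # 1 = factual error, 0 = correct

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"Held-out accuracy: {probe.score(X_te, y_te):.2f}")
```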

# Contact
Mert Yuksekgonul ([email protected])
Besmira Nushi ([email protected])

# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [email protected] with any additional questions or comments.

## Trademarks

Binary file added assets/sample_jordan.png
29 changes: 29 additions & 0 deletions factual_queries/__init__.py
@@ -0,0 +1,29 @@
import numpy as np
from .single_constraint import *
from .multi_constraint import *

def load_constraint_dataset(dataset_name, subsample_count=None):
    """Load one of the factual-query datasets by name, optionally subsampling items."""
    if dataset_name == "basketball_players":
        items = load_basketball_players()
    elif dataset_name == "football_teams":
        items = load_football_teams()
    elif dataset_name == "songs":
        items = load_songs()
    elif dataset_name == "movies":
        items = load_movies()
    elif "counterfact_" in dataset_name:
        items = load_counterfact_subset(dataset_name)
    elif dataset_name == "nobel":
        items = load_nobel_city()
    elif dataset_name == "words":
        items = load_word_startend()
    elif dataset_name == "books":
        items = load_books()
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if subsample_count is not None:
        # Draw a random subset of items without replacement.
        items = np.random.choice(items, subsample_count, replace=False)

    return items
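
A quick usage sketch of this loader (run from the repository root so the relative `./factual_queries/data/` paths inside the loaders resolve):
```python
# Quick usage sketch; run from the repository root so the relative
# ./factual_queries/data/ paths inside the loaders resolve.
from factual_queries import load_constraint_dataset

items = load_constraint_dataset("basketball_players", subsample_count=100)
print(len(items))          # 100
print(items[0]["prompt"])  # formatted User/Assistant prompt
```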
Binary file added factual_queries/data/basketball_players.pkl
Binary file not shown.
Binary file added factual_queries/data/books_multiconstraint.pkl
Binary file not shown.
Binary file added factual_queries/data/counterfact_citizenship.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added factual_queries/data/football_teams.pkl
Binary file not shown.
Binary file added factual_queries/data/movies.pkl
Binary file not shown.
Binary file added factual_queries/data/nobel_multiconstraint.pkl
Binary file not shown.
Binary file added factual_queries/data/schools.pkl
Binary file not shown.
Binary file added factual_queries/data/songs.pkl
Binary file not shown.
Binary file added factual_queries/data/word_multiconstraint.pkl
Binary file not shown.
19 changes: 19 additions & 0 deletions factual_queries/multi_constraint.py
@@ -0,0 +1,19 @@
import pickle

### Multi-constraint datasets

def load_nobel_city():
    filename = "./factual_queries/data/nobel_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items

def load_word_startend():
    filename = "./factual_queries/data/word_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items

def load_books():
    filename = "./factual_queries/data/books_multiconstraint.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items
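
These loaders return the pickled items unchanged; a minimal sketch of loading one of them (again assuming you run from the repository root):
```python
# Minimal sketch; run from the repository root so the relative
# ./factual_queries/data/ paths resolve.
from factual_queries.multi_constraint import load_books

items = load_books()
print(len(items), type(items[0]))
```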
60 changes: 60 additions & 0 deletions factual_queries/single_constraint.py
@@ -0,0 +1,60 @@
import pickle

### Single-constraint datasets

def load_basketball_players():
    with open("./factual_queries/data/basketball_players.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the year the basketball player {} was born in."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The player was born in"
    for item in items:
        item["constraint"] = item["player_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["birth_year"]
        # 'popularity' is already present in the raw items.
    return items


def load_football_teams():
    with open("./factual_queries/data/football_teams.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the year the football team {} was founded in."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The team was founded in"
    for item in items:
        item["constraint"] = item["team_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["founding_year"]
    return items


def load_songs():
    with open("./factual_queries/data/songs.pkl", "rb") as f:
        items = pickle.load(f)

    prompt_template = "Tell me the performer of the song {}"
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The performer is"
    for item in items:
        item["constraint"] = item["song_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["artist_name"]
    return items


def load_movies():
    with open("./factual_queries/data/movies.pkl", "rb") as f:
        items = pickle.load(f)
    prompt_template = "Tell me the director of the movie {}."
    prompt_fn = lambda prompt: f"User: {prompt}\nAssistant: The director is"
    for item in items:
        item["constraint"] = item["movie_name"]
        item["prompt"] = prompt_fn(prompt_template.format(item["constraint"]))
        item["label"] = item["director_name"]
    return items


def load_counterfact_subset(subset):
    filename = f"./factual_queries/data/{subset}.pkl"
    with open(filename, "rb") as f:
        items = pickle.load(f)
    return items
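
Each single-constraint loader attaches `constraint`, `prompt`, and `label` fields to the raw items; a quick inspection sketch (run from the repository root):
```python
# Quick inspection of the fields attached by the loaders above;
# run from the repository root so the relative data paths resolve.
from factual_queries.single_constraint import load_movies

item = load_movies()[0]
print(item["constraint"])  # movie name
print(item["label"])       # director name
print(item["prompt"])      # "User: ...\nAssistant: The director is"
```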