Commit

clean accuracy code

jennhu committed Oct 3, 2019
1 parent b575404 commit 8cdba09
Showing 27 changed files with 270 additions and 450 deletions.
38 changes: 30 additions & 8 deletions README.md
@@ -64,7 +64,7 @@ training corpus. See the paper for more details on how we
constructed our materials.

### Vocabulary issues
-In all of our novel materials (**TODO: list the experiment names**), the
+In our novel materials (used in `['exp2-rc-all', 'exp3-comp', 'exp4-pp']`), the
lexical items are designed to be in-vocabulary for models trained on the
Penn Treebank. This is not the case for the materials used in Experiment 1, the
[Marvin & Linzen (2018)](https://arxiv.org/abs/1808.09031) replication.
@@ -73,11 +73,16 @@ Penn Treebank. This is not the case for the materials used in Experiment 1, the
The per-token surprisal values for each model can be found in the [data](data)
folder, following this naming convention:
```
-data/<MODEL>/<EXPERIMENT>/<PRONOUN>_<MODEL>.txt
+data/surprisal/<MODEL>/<EXPERIMENT>/<PRONOUN>_<MODEL>.txt
```
The BERT data is in a slightly different `.csv` format, but otherwise
follows the same naming convention.

The accuracy results can be found at
```
data/accuracy/<EXPERIMENT>.csv
```
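
For illustration, a minimal sketch of locating and loading one of these files from the repository root (the model, experiment, and pronoun names are placeholders, and the per-token `.txt` layout is an assumption, not a documented format):

```python
import pandas as pd

model, exp, pronoun = 'grnn', 'exp3-comp', 'themselves'  # placeholder names
ext = 'csv' if model == 'bert' else 'txt'
surp_path = f'data/surprisal/{model}/{exp}/{pronoun}_{model}.{ext}'

if model == 'bert':
    surprisals = pd.read_csv(surp_path)  # BERT surprisals ship as .csv
else:
    # Assumed layout: whitespace-separated token/surprisal columns.
    surprisals = pd.read_csv(surp_path, sep=r'\s+', header=None,
                             names=['token', 'surprisal'])
```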

## Dependencies
Our analysis code requires a basic scientific installation of Python
(`numpy`, `pandas`, `matplotlib`, `seaborn`, etc.).
@@ -100,14 +105,18 @@ We can make the training script for our n-gram model available upon request.
## Reproducing our results

### Figures
-To generate the plots for a given experiment and model, run the following:
+To generate the plots for a given experiment and list of models, run the following:

```bash
cd analysis
-mkdir figures
-python generate_lot.py -o figures -model <MODELS> -exp <EXPERIMENT> -vs
+mkdir -p figures
+python generate_plot.py -o figures -model <MODELS> -exp <EXPERIMENT> -vs
```
-This will save a plot to `analysis/figures/<EXPERIMENT>_<MODEL>.png`.
+This will save a plot to `analysis/figures/<EXPERIMENT>-<MODELS>.png`.
Note that `<MODELS>` can be a list of model names (e.g. `-model rnng bert jrnn`),
`'big'` for large-vocabulary models, or `'all'` for all models. The
large-vocabulary models are **BERT, Transformer-XL, JRNN, GRNN, 5-gram**.

The `-vs` flag tells the script to plot the negative log probability **differential**.
You can omit the flag to plot the raw negative log probabilities.
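
For example, a concrete invocation for Experiment 3 with two of the large-vocabulary models (the hyphenated output name follows the convention above; joining multiple model names with underscores is an assumption, mirroring `compute_accuracy.py` below):

```bash
cd analysis
mkdir -p figures
python generate_plot.py -o figures -model jrnn grnn -exp exp3-comp -vs
# expected output: figures/exp3-comp-jrnn_grnn.png
```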

@@ -117,9 +126,22 @@ if it does not exist):

```bash
cd analysis
-./plot_all figures
+./get_figures figures
```

### Accuracy

-**TODO**
Similarly, to compute the accuracy for a given experiment and list of models,
run:
```bash
cd analysis
mkdir -p accuracy
python compute_accuracy.py -o accuracy -model <MODELS> -exp <EXPERIMENT>
```
This will save a file to `analysis/accuracy/<EXPERIMENT>-<MODELS>.csv`.

To compute the accuracy for all our experiments, run the following:
```bash
cd analysis
./get_accuracy accuracy
```
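
The resulting files can be inspected directly; a quick sketch (the file name is assumed from the convention above, and the column names come from `compute_accuracy.py` below):

```python
import pandas as pd

acc = pd.read_csv('analysis/accuracy/exp3-comp-jrnn_grnn.csv')
# columns: model, total_acc, exp, pronoun, vs_baseline_acc, vs_distractor_acc
print(acc.groupby('model').total_acc.mean())
```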
164 changes: 0 additions & 164 deletions analysis/accuracy.py

This file was deleted.

120 changes: 120 additions & 0 deletions analysis/compute_accuracy.py
@@ -0,0 +1,120 @@
"""
accuracy.py
Get accuracy results.
"""
import argparse
from pathlib import Path
from numpy import mean
import random
import pandas as pd

import utils

def get_accuracy(df, distractor_pos):
    item_list = df.item.unique()
    n_items = len(item_list)
    num_correct_vs_baseline = 0
    num_correct_vs_distractor = 0
    num_correct = 0

    for item in item_list:
        item_rows = df[df.item == item]
        baseline_rows = item_rows[item_rows.mismatch_position == 'none']
        distractor_rows = item_rows[item_rows.mismatch_position == distractor_pos]
        ungrammatical_rows = item_rows[item_rows.grammatical == 0]

        vs_baseline = ungrammatical_rows.surprisal.mean() - baseline_rows.surprisal.mean()
        vs_distractor = ungrammatical_rows.surprisal.mean() - distractor_rows.surprisal.mean()

        # Check if ungrammatical - baseline is positive.
        if vs_baseline > 0:
            num_correct_vs_baseline += 1

        # Check if ungrammatical - distractor is positive.
        if vs_distractor > 0:
            num_correct_vs_distractor += 1

        # Check if both differentials are positive.
        if vs_baseline > 0 and vs_distractor > 0:
            num_correct += 1

        # If both differentials are zero, then label correct with probability 1/3.
        elif vs_baseline == 0 and vs_distractor == 0:
            choice = random.choice(['baseline', 'distractor', 'ungrammatical'])
            if choice == 'ungrammatical':
                num_correct += 1

    # Calculate proportion of items where different accuracy conditions hold.
    vs_baseline_acc = num_correct_vs_baseline / float(n_items)
    vs_distractor_acc = num_correct_vs_distractor / float(n_items)
    total_acc = num_correct / float(n_items)

    return total_acc, vs_baseline_acc, vs_distractor_acc

#################################################################################
# Main function -- partially shared with generate_plot.py
#################################################################################

def main(args):
    # Get list of model names.
    if args.model == ['all']:
        model_list = utils.MODELS
    elif args.model == ['big']:
        model_list = utils.BIG_MODELS
    else:
        model_list = args.model

    # Ensure only large-vocabulary models are specified for M&L replication.
    if 'ml' in args.exp and any(m not in utils.BIG_MODELS for m in model_list):
        raise ValueError(
            'Only large-vocabulary models are compatible with '
            'Marvin & Linzen\'s (2018) materials. '
            'Please use "--model big" to plot the results from that experiment.'
        )

    # Assign file name based on name of experiment and specified models.
    out_path = Path(f'{args.out_prefix}/{args.exp}-{"_".join(args.model)}.csv')

    acc_dict = []
    for model in model_list:
        # Get data for each pronoun for current model.
        for pn in utils.PRONOUNS:
            surp_ext = 'csv' if model == 'bert' else 'txt'
            surp_path = Path(
                f'../data/surprisal/{model}/{args.exp}/{pn}_{model}.{surp_ext}'
            )
            if model == 'bert':
                pn_df = pd.read_csv(surp_path)
            else:
                data_path = Path(f'../stimuli/{args.exp}/{pn}.csv')
                pn_df = utils.get_data_df(data_path, surp_path, args.exp, pn)

            # Assign appropriate mismatch position for distractor condition.
            if 'rc' in args.exp:
                distractor_pos = 'rc_subj'
            elif 'comp' in args.exp or 'ml' in args.exp:
                distractor_pos = 'nonlocal_subj'
            else:
                distractor_pos = 'distractor'

            total_acc, vs_baseline_acc, vs_distractor_acc = get_accuracy(
                pn_df, distractor_pos
            )
            acc_dict.append(dict(
                model=model, total_acc=total_acc, exp=args.exp, pronoun=pn,
                vs_baseline_acc=vs_baseline_acc, vs_distractor_acc=vs_distractor_acc
            ))
    acc_df = pd.DataFrame(acc_dict)
    acc_df.to_csv(out_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compute accuracy for models.')
    parser.add_argument('--out_prefix', '-out_prefix', '--o', '-o',
                        help='prefix to path to save final .csv file '
                             '(file will be named according to experiment)')
    parser.add_argument('--model', '-model', '--m', '-m', nargs='+',
                        help='list of model names, or "all" or "big"')
    parser.add_argument('--exp', '-exp', help='name of experiment')
    args = parser.parse_args()
    main(args)
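
To make the accuracy criterion concrete, here is a hypothetical toy example of calling `get_accuracy` on a hand-built frame (the column values and the `local_subj` label are invented; only the column names match the code above):

```python
import pandas as pd
from compute_accuracy import get_accuracy  # run from analysis/

# One item, three conditions: a fully matched baseline, a condition with a
# mismatching distractor, and an ungrammatical condition (grammatical == 0).
toy = pd.DataFrame([
    dict(item=1, mismatch_position='none',       grammatical=1, surprisal=2.0),
    dict(item=1, mismatch_position='distractor', grammatical=1, surprisal=3.0),
    dict(item=1, mismatch_position='local_subj', grammatical=0, surprisal=5.0),
])

total, vs_baseline, vs_distractor = get_accuracy(toy, 'distractor')
# The ungrammatical surprisal (5.0) exceeds both controls, so all three
# accuracies are 1.0 for this single item.
print(total, vs_baseline, vs_distractor)
```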
21 changes: 21 additions & 0 deletions analysis/get_accuracy
@@ -0,0 +1,21 @@
#!/bin/bash

if [ "$#" -ne 1 ]; then
    echo "Expected usage: ./get_accuracy <output_folder>"
    exit 1
fi

mkdir -p "$1"

ML_EXPS=("exp1a-ml-rc" "exp1b-ml-comp")
OTHER_EXPS=("exp2-rc" "exp3-comp" "exp4-pp")
EXPS=("${ML_EXPS[@]}" "${OTHER_EXPS[@]}")

for exp in "${EXPS[@]}"; do
    echo "== Computing accuracy for $exp =="
    # Only large-vocabulary models are compatible with the M&L materials.
    if [[ " ${ML_EXPS[*]} " == *" $exp "* ]]; then
        model="big"
    else
        model="all"
    fi
    python compute_accuracy.py -o "$1" -model $model -exp $exp
done
2 changes: 1 addition & 1 deletion analysis/plot_all → analysis/get_figures
@@ -1,7 +1,7 @@
#!/bin/bash

if [ "$#" -ne 1 ]; then
echo "Expected usage: ./plot_all <output_folder>"
echo "Expected usage: ./get_figures <output_folder>"
fi

mkdir -p $1