Merge pull request #223 from lvapeab/master
Add evaluation using SacreBleu
lvapeab authored Mar 24, 2020
2 parents fe7f41b + 50ab7b0 commit 6581c68
Showing 4 changed files with 81 additions and 7 deletions.
66 changes: 60 additions & 6 deletions keras_wrapper/extra/evaluation.py
@@ -4,6 +4,7 @@
from builtins import map, zip
import json
import logging

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger(__name__)

@@ -12,6 +13,60 @@

# EVALUATION FUNCTIONS SELECTOR

def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
    """
    SacreBLEU metrics
    :param pred_list: list of hypothesis sentences
    :param verbose: if greater than 0, the computed scores are logged
    :param extra_vars: extra variables, here:
            extra_vars[split]['references'] - dict mapping sample indices to the list of all valid references (id, [sentences])
            extra_vars['tokenize_f'] - tokenization function used during model training (used again for validation)
            extra_vars['detokenize_f'] - detokenization function used during model training (used again for validation)
            extra_vars['tokenize_hypotheses'] - whether or not to tokenize the hypotheses during evaluation
    :param split: split on which we are evaluating
    :return: dictionary with the SacreBLEU scores
    """
    import sacrebleu
    gts = extra_vars[split]['references']
    if extra_vars.get('tokenize_hypotheses', False):
        hypo = [extra_vars['tokenize_f'](line.strip()) for line in pred_list]
    else:
        hypo = [line.strip() for line in pred_list]

    initial_references = gts.get(0)
    if initial_references is None:
        raise ValueError('You need to provide at least one reference')

    # SacreBLEU expects the references transposed: one list per reference index,
    # each of them parallel to the list of hypotheses.
    num_references = len(initial_references)
    refs = [[] for _ in range(num_references)]
    for references in gts.values():
        assert len(references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
        for ref_idx, reference in enumerate(references):
            # De/tokenize references if needed
            tokenized_ref = extra_vars['tokenize_f'](reference) if extra_vars.get('tokenize_references', False) else reference
            detokenized_ref = extra_vars['detokenize_f'](tokenized_ref) if extra_vars.get('apply_detokenization', False) else tokenized_ref
            refs[ref_idx].append(detokenized_ref)

    scorers = [
        (sacrebleu.corpus_bleu, "Bleu_4"),
    ]

    final_scores = {}
    for scorer, method in scorers:
        score = scorer(hypo, refs)
        final_scores[method] = score.score

    if verbose > 0:
        logger.info('Computing SacreBleu scores on the %s split...' % split)
        for metric in sorted(final_scores):
            value = final_scores[metric]
            logger.info(metric + ': ' + str(value))

    return final_scores

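For reference, calling the new scorer could look like the sketch below (not part of this commit; the split name, sentences, and settings are invented for illustration, and it assumes sacrebleu is installed):

from keras_wrapper.extra.evaluation import get_sacrebleu_score

# Hypothetical references for a 'val' split: sample index -> list of valid reference sentences.
extra_vars = {'val': {'references': {0: ['the cat sat on the mat'],
                                     1: ['there is a dog in the garden']}},
              'tokenize_hypotheses': False}  # the hypotheses below are already plain strings

hypotheses = ['the cat sat on the mat', 'a dog is in the garden']

# verbose=1 logs the result; the function returns {'Bleu_4': <corpus-level BLEU>}.
scores = get_sacrebleu_score(hypotheses, 1, extra_vars, 'val')
print(scores['Bleu_4'])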

def get_coco_score(pred_list, verbose, extra_vars, split):
"""
COCO challenge metrics
Expand Down Expand Up @@ -91,7 +146,7 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
    import datetime
    import os
    from pycocoevalcap.vqa import vqaEval, visual_qa
-    from read_write import list2vqa
+    from keras_wrapper.extra.read_write import list2vqa

    quesFile = extra_vars[split]['quesFile']
    annFile = extra_vars[split]['annFile']
@@ -103,7 +158,8 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
    # create vqa object and vqaRes object
    vqa_ = visual_qa.VQA(annFile, quesFile)
    vqaRes = vqa_.loadRes(resFile, quesFile)
-    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2
+    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes,
+                               n=2)  # n is precision of accuracy (number of places after decimal), default is 2
    vqaEval_.evaluate()
    os.remove(resFile)  # remove temporal file
@@ -189,7 +245,8 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):

    if verbose > 0:
        logger.info(
-            '"coverage_error" (best: avg labels per sample = %f): %f' % (float(np.sum(y_gt)) / float(n_samples), coverr))
+            '"coverage_error" (best: avg labels per sample = %f): %f' % (
+                float(np.sum(y_gt)) / float(n_samples), coverr))
        logger.info('Label Ranking "average_precision" (best: 1.0): %f' % avgprec)
        logger.info('Label "ranking_loss" (best: 0.0): %f' % rankloss)
        logger.info('precision: %f' % precision)
@@ -204,9 +261,6 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
            'f1': f1}


-import numpy as np


def multiclass_metrics(pred_list, verbose, extra_vars, split):
"""
Multiclass classification metrics. See multilabel ranking metrics in sklearn library for more info:
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ toolz
cloudpickle
matplotlib
sacremoses
sacrebleu
scipy
future
cython
1 change: 1 addition & 0 deletions setup.py
@@ -30,6 +30,7 @@
'cloudpickle',
'matplotlib',
'sacremoses',
'sacrebleu',
'scipy',
'future',
'cython',
20 changes: 19 additions & 1 deletion tests/extra/test_wrapper_evaluation.py
@@ -1,7 +1,25 @@
import pytest
-from keras_wrapper.extra.evaluation import *
+import numpy as np
+from keras_wrapper.extra.evaluation import get_sacrebleu_score, get_coco_score, multilabel_metrics, compute_perplexity


def test_get_sacrebleu_score():
    pred_list = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
    extra_vars = {'val': {'references': {0: ['Prediction 1 X W Z', 'Prediction 5'],
                                         1: ['Prediction 2 X W Z', 'X Y Z'],
                                         2: ['Prediction 3 X W Z', 'Prediction 5']}},
                  'test': {'references': {0: ['Prediction 2 X W Z'],
                                          1: ['Prediction 3 X W Z'],
                                          2: ['Prediction 1 X W Z']}}
                  }
    val_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'val')
    assert np.allclose(val_scores['Bleu_4'], 100.0, atol=1e6)

    test_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'test')
    assert np.allclose(test_scores['Bleu_4'], 0., atol=1e6)

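The exact-match behaviour targeted by the 'val' assertion above can be reproduced with sacrebleu alone (a standalone sketch, not part of the commit, assuming sacrebleu is installed):

import sacrebleu

hyps = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
refs = [['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']]  # one reference stream, parallel to hyps

# Identical hypotheses and references give a corpus BLEU of 100.0.
print(sacrebleu.corpus_bleu(hyps, refs).score)
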
def test_get_coco_score():
    pred_list = ['Prediction 1', 'Prediction 2', 'Prediction 3']
    extra_vars = {'val': {'references': {0: ['Prediction 1'], 1: ['Prediction 2'],
