diff --git a/keras_wrapper/extra/evaluation.py b/keras_wrapper/extra/evaluation.py
index ab2c1ae..403a936 100644
--- a/keras_wrapper/extra/evaluation.py
+++ b/keras_wrapper/extra/evaluation.py
@@ -4,6 +4,7 @@ from builtins import map, zip
 
 import json
 import logging
+
 logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
 logger = logging.getLogger(__name__)
 
@@ -12,6 +13,60 @@
 # EVALUATION FUNCTIONS SELECTOR
 
 
+def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
+    """
+    SacreBLEU metrics
+    :param pred_list: list of hypothesis sentences (one string per sample)
+    :param verbose: if greater than 0 the metric measures are printed out
+    :param extra_vars: extra variables, here are:
+            extra_vars['references'] - dict mapping sample indices to list with all valid captions (id, [sentences])
+            extra_vars['tokenize_f'] - tokenization function used during model training (used again for validation)
+            extra_vars['detokenize_f'] - detokenization function used during model training (used again for validation)
+            extra_vars['tokenize_hypotheses'] - whether or not to tokenize the hypotheses during evaluation
+    :param split: split on which we are evaluating
+    :return: Dictionary with the SacreBLEU scores
+    """
+    import sacrebleu
+    gts = extra_vars[split]['references']
+    if extra_vars.get('tokenize_hypotheses', False):
+        hypo = [extra_vars['tokenize_f'](line.strip())
+                for line in pred_list]
+    else:
+        hypo = [line.strip() for line in pred_list]
+
+    initial_references = gts.get(0)
+    if initial_references is None:
+        raise ValueError('You need to provide at least one reference')
+
+    num_references = len(initial_references)
+    refs = [[] for _ in range(num_references)]
+    for references in gts.values():
+        assert len(references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
+        for ref_idx, reference in enumerate(references):
+            # De/tokenize references if needed
+            tokenized_ref = extra_vars['tokenize_f'](reference) if extra_vars.get('tokenize_references', False) \
+                else reference
+            detokenized_ref = extra_vars['detokenize_f'](tokenized_ref) if extra_vars.get('apply_detokenization', False) else tokenized_ref
+            refs[ref_idx].append(detokenized_ref)
+
+    scorers = [
+        (sacrebleu.corpus_bleu, "Bleu_4"),
+    ]
+
+    final_scores = {}
+    for scorer, method in scorers:
+        score = scorer(hypo, refs)
+        final_scores[method] = score.score
+
+    if verbose > 0:
+        logger.info('Computing SacreBLEU scores on the %s split...' % split)
+        for metric in sorted(final_scores):
+            value = final_scores[metric]
+            logger.info(metric + ': ' + str(value))
+
+    return final_scores
+
+
 def get_coco_score(pred_list, verbose, extra_vars, split):
     """
     COCO challenge metrics
@@ -91,7 +146,7 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
     import datetime
     import os
     from pycocoevalcap.vqa import vqaEval, visual_qa
-    from read_write import list2vqa
+    from keras_wrapper.extra.read_write import list2vqa
 
     quesFile = extra_vars[split]['quesFile']
     annFile = extra_vars[split]['annFile']
@@ -103,7 +158,8 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
 
     # create vqa object and vqaRes object
     vqa_ = visual_qa.VQA(annFile, quesFile)
     vqaRes = vqa_.loadRes(resFile, quesFile)
-    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes, n=2)  # n is precision of accuracy (number of places after decimal), default is 2
+    vqaEval_ = vqaEval.VQAEval(vqa_, vqaRes,
+                               n=2)  # n is precision of accuracy (number of places after decimal), default is 2
     vqaEval_.evaluate()
     os.remove(resFile)  # remove temporal file
@@ -189,7 +245,8 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
 
     if verbose > 0:
         logger.info(
-            '"coverage_error" (best: avg labels per sample = %f): %f' % (float(np.sum(y_gt)) / float(n_samples), coverr))
+            '"coverage_error" (best: avg labels per sample = %f): %f' % (
+                float(np.sum(y_gt)) / float(n_samples), coverr))
         logger.info('Label Ranking "average_precision" (best: 1.0): %f' % avgprec)
         logger.info('Label "ranking_loss" (best: 0.0): %f' % rankloss)
         logger.info('precision: %f' % precision)
@@ -204,9 +261,6 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
             'f1': f1}
 
 
-import numpy as np
-
-
 def multiclass_metrics(pred_list, verbose, extra_vars, split):
     """
     Multiclass classification metrics. See multilabel ranking metrics in sklearn library for more info:
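Usage note (editor's sketch, not part of the patch): get_sacrebleu_score takes a plain list of hypothesis strings and reads the references for the given split from extra_vars[split]['references'], a dict mapping each sample index to its list of reference strings. Every sample must provide the same number of references, and the per-sample lists are transposed into sacrebleu's reference streams before sacrebleu.corpus_bleu is called. A minimal call mirroring the unit test added below might look like this (the sentences are made up for illustration):

    from keras_wrapper.extra.evaluation import get_sacrebleu_score

    pred_list = ['the cat sat on the mat', 'a dog barked loudly']        # one hypothesis string per sample
    extra_vars = {'val': {'references': {0: ['the cat sat on the mat'],  # sample index -> list of references
                                         1: ['a dog barked loudly']}}}

    scores = get_sacrebleu_score(pred_list, 1, extra_vars, 'val')        # verbose=1 logs the computed scores
    print(scores['Bleu_4'])                                              # corpus BLEU on a 0-100 scale
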
diff --git a/requirements.txt b/requirements.txt
index 90a524e..482378d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ toolz
 cloudpickle
 matplotlib
 sacremoses
+sacrebleu
 scipy
 future
 cython
diff --git a/setup.py b/setup.py
index 6bfc16a..16d31e3 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,7 @@
         'cloudpickle',
         'matplotlib',
         'sacremoses',
+        'sacrebleu',
         'scipy',
         'future',
         'cython',
diff --git a/tests/extra/test_wrapper_evaluation.py b/tests/extra/test_wrapper_evaluation.py
index e3f7572..6b2843c 100644
--- a/tests/extra/test_wrapper_evaluation.py
+++ b/tests/extra/test_wrapper_evaluation.py
@@ -1,7 +1,25 @@
 import pytest
-from keras_wrapper.extra.evaluation import *
+import numpy as np
+from keras_wrapper.extra.evaluation import get_sacrebleu_score, get_coco_score, multilabel_metrics, compute_perplexity
 
 
+def test_get_sacrebleu_score():
+    pred_list = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
+    extra_vars = {'val': {'references': {0: ['Prediction 1 X W Z', 'Prediction 5'],
+                                         1: ['Prediction 2 X W Z', 'X Y Z'],
+                                         2: ['Prediction 3 X W Z', 'Prediction 5']}},
+
+                  'test': {'references': {0: ['Prediction 2 X W Z'],
+                                          1: ['Prediction 3 X W Z'],
+                                          2: ['Prediction 1 X W Z']}}
+                  }
+    val_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'val')
+    assert np.allclose(val_scores['Bleu_4'], 100.0, atol=1e6)
+
+
+    test_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'test')
+    assert np.allclose(test_scores['Bleu_4'], 0., atol=1e6)
+
+
 def test_get_coco_score():
     pred_list = ['Prediction 1', 'Prediction 2', 'Prediction 3']
     extra_vars = {'val': {'references': {0: ['Prediction 1'], 1: ['Prediction 2'],