diff --git a/torchseq/metric_hooks/sep_ae.py b/torchseq/metric_hooks/sep_ae.py index 4ea6271..14ccd54 100644 --- a/torchseq/metric_hooks/sep_ae.py +++ b/torchseq/metric_hooks/sep_ae.py @@ -855,6 +855,19 @@ def eval_gen_codepred_diversity( topk_outputs.append(output) + # calculate p-BLEU (Cao and Wan, 2020) + # p-BLEU = sum_i, sum_{j neq i} BLEU(yi, yj) / (k * (k-1)), i.e. the mean over all ordered pairs i != j + pbleu_scores = [] + for i in range(top_k): + for j in range(top_k): + if i == j: + continue + this_bleu = sacrebleu.corpus_bleu( + topk_outputs[i], list(zip(*[[x] for x in topk_outputs[j]])), lowercase=True + ).score + pbleu_scores.append(this_bleu) + scores["pbleu"] = np.mean(pbleu_scores) + agent.config.eval.data["sample_outputs"] = sample_outputs return scores, topk_outputs