From ae8d872eb7b5daaace48441fef360da03c85200f Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Wed, 31 Jul 2024 14:43:15 +0900 Subject: [PATCH] Fix bfloat bug --- src/jmteb/evaluators/sts/evaluator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jmteb/evaluators/sts/evaluator.py b/src/jmteb/evaluators/sts/evaluator.py index cbea7e2..b7b8eb8 100644 --- a/src/jmteb/evaluators/sts/evaluator.py +++ b/src/jmteb/evaluators/sts/evaluator.py @@ -109,6 +109,8 @@ def _compute_similarity( embeddings1: Tensor, embeddings2: Tensor, golden_scores: list, similarity_func: Callable ) -> tuple[dict[str, float], list[float]]: sim_scores = similarity_func(embeddings1, embeddings2).cpu() + if isinstance(sim_scores, Tensor) and sim_scores.dtype is torch.bfloat16: + sim_scores = sim_scores.float() pearson = pearsonr(golden_scores, sim_scores)[0] spearman = spearmanr(golden_scores, sim_scores)[0] return {