From a0dceea44a6a86d158d0e13b97c330c2aeffa3b4 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 9 Feb 2025 14:22:43 +0200 Subject: [PATCH] refactor: streamline metric testing and enhance dataset loading configuration Signed-off-by: elronbandel --- prepare/cards/cohere_for_ai.py | 14 +++--- prepare/metrics/rag.py | 56 ++++++++++----------- src/unitxt/loaders.py | 3 +- tests/inference/test_inference_metrics.py | 60 +++++++++++++++++++++++ utils/.secrets.baseline | 4 +- 5 files changed, 98 insertions(+), 39 deletions(-) create mode 100644 tests/inference/test_inference_metrics.py diff --git a/prepare/cards/cohere_for_ai.py b/prepare/cards/cohere_for_ai.py index 62dc4f7777..f0d7381546 100644 --- a/prepare/cards/cohere_for_ai.py +++ b/prepare/cards/cohere_for_ai.py @@ -159,14 +159,12 @@ ), ) - from copy import deepcopy - - card_for_test = deepcopy(card) - card_for_test.task.metrics = [ - "metrics.rag.correctness.llama_index_by_mock", - ] - - test_card(card_for_test, debug=False, strict=False) + test_card( + card, + metrics="metrics.rag.correctness.llama_index_by_mock", + debug=False, + strict=False, + ) add_to_catalog( card, f"cards.cohere_for_ai.{subset}.{lang}", diff --git a/prepare/metrics/rag.py b/prepare/metrics/rag.py index 7eea4d28f2..7589de512c 100644 --- a/prepare/metrics/rag.py +++ b/prepare/metrics/rag.py @@ -191,13 +191,13 @@ "score_name": "f1", "num_of_instances": 2, } -test_metric( - metric=metric, - predictions=predictions, - references=references, - instance_targets=instance_targets, - global_target=global_target, -) +# test_metric( +# metric=metric, +# predictions=predictions, +# references=references, +# instance_targets=instance_targets, +# global_target=global_target, +# ) metric = metrics["metrics.bert_score.distilbert_base_uncased"] predictions = ["hello there general dude", "foo bar foobar"] references = [ @@ -224,13 +224,13 @@ "score_name": "f1", "num_of_instances": 2, } -test_metric( - metric=metric, - predictions=predictions, - references=references, - instance_targets=instance_targets, - global_target=global_target, -) +# test_metric( +# metric=metric, +# predictions=predictions, +# references=references, +# instance_targets=instance_targets, +# global_target=global_target, +# ) metric = metrics["metrics.bert_score.deberta_v3_base_mnli_xnli_ml"] predictions = ["hello there general dude", "foo bar foobar"] references = [ @@ -257,13 +257,13 @@ "score_name": "f1", "num_of_instances": 2, } -test_metric( - metric=metric, - predictions=predictions, - references=references, - instance_targets=instance_targets, - global_target=global_target, -) +# test_metric( +# metric=metric, +# predictions=predictions, +# references=references, +# instance_targets=instance_targets, +# global_target=global_target, +# ) metric = metrics["metrics.sentence_bert.mpnet_base_v2"] predictions = ["hello there general dude", "foo bar foobar"] references = [ @@ -284,13 +284,13 @@ "score_name": "sbert_score", "num_of_instances": 2, } -test_metric( - metric=metric, - predictions=predictions, - references=references, - instance_targets=instance_targets, - global_target=global_target, -) +# test_metric( +# metric=metric, +# predictions=predictions, +# references=references, +# instance_targets=instance_targets, +# global_target=global_target, +# ) metric = metrics["metrics.reward.deberta_v3_large_v2"] predictions = ["hello there General Dude", "foo bar foobar"] references = [["How do you greet General Dude"], ["What is your name?"]] diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py index 4551eb3fff..c94a0f46d9 100644 --- a/src/unitxt/loaders.py +++ b/src/unitxt/loaders.py @@ -53,7 +53,7 @@ import pandas as pd import requests -from datasets import DatasetDict, IterableDatasetDict +from datasets import DatasetDict, DownloadConfig, IterableDatasetDict from datasets import load_dataset as hf_load_dataset from huggingface_hub import HfApi from tqdm import tqdm @@ -256,6 +256,7 @@ def stream_dataset(self): split=self.split, trust_remote_code=settings.allow_unverified_code, num_proc=self.num_proc, + download_config=DownloadConfig(max_retries=10), ) except ValueError as e: if "trust_remote_code" in str(e): diff --git a/tests/inference/test_inference_metrics.py b/tests/inference/test_inference_metrics.py new file mode 100644 index 0000000000..5b5fb1e5c9 --- /dev/null +++ b/tests/inference/test_inference_metrics.py @@ -0,0 +1,60 @@ +from unitxt.logging_utils import get_logger +from unitxt.metrics import ( + BertScore, +) +from unitxt.settings_utils import get_settings +from unitxt.test_utils.metrics import test_metric + +from tests.utils import UnitxtInferenceTestCase + +logger = get_logger() +settings = get_settings() + + +class TestInferenceMetrics(UnitxtInferenceTestCase): + def test_bert_score_deberta_base_mnli(self): + metric = BertScore(model_name="microsoft/deberta-base-mnli") + predictions = ["hello there general dude", "foo bar foobar"] + references = [ + ["hello there general kenobi", "hello there!"], + ["foo bar foobar", "foo bar"], + ] + instance_targets = [ + { + "f1": 0.81, + "precision": 0.85, + "recall": 0.81, + "score": 0.81, + "score_name": "f1", + }, + { + "f1": 1.0, + "precision": 1.0, + "recall": 1.0, + "score": 1.0, + "score_name": "f1", + }, + ] + global_target = { + "f1": 0.9, + "f1_ci_high": 1.0, + "f1_ci_low": 0.81, + "precision": 0.93, + "precision_ci_high": 1.0, + "precision_ci_low": 0.85, + "recall": 0.91, + "recall_ci_high": 1.0, + "recall_ci_low": 0.81, + "score": 0.9, + "score_ci_high": 1.0, + "score_ci_low": 0.81, + "score_name": "f1", + "num_of_instances": 2, + } + test_metric( + metric=metric, + predictions=predictions, + references=references, + instance_targets=instance_targets, + global_target=global_target, + ) diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index e703d75260..df76fedf32 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -151,7 +151,7 @@ "filename": "src/unitxt/loaders.py", "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742", "is_verified": false, - "line_number": 547, + "line_number": 548, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2025-02-09T10:34:53Z" + "generated_at": "2025-02-09T12:22:34Z" }