Skip to content

Commit

Permalink
refactor: streamline metric testing and enhance dataset loading configuration
Browse files Browse the repository at this point in the history

Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel committed Feb 9, 2025
1 parent 12e328b commit a0dceea
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 39 deletions.
14 changes: 6 additions & 8 deletions prepare/cards/cohere_for_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,12 @@
),
)

from copy import deepcopy

card_for_test = deepcopy(card)
card_for_test.task.metrics = [
"metrics.rag.correctness.llama_index_by_mock",
]

test_card(card_for_test, debug=False, strict=False)
test_card(
card,
metrics="metrics.rag.correctness.llama_index_by_mock",
debug=False,
strict=False,
)
add_to_catalog(
card,
f"cards.cohere_for_ai.{subset}.{lang}",
Expand Down
56 changes: 28 additions & 28 deletions prepare/metrics/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,13 @@
"score_name": "f1",
"num_of_instances": 2,
}
test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )
metric = metrics["metrics.bert_score.distilbert_base_uncased"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
Expand All @@ -224,13 +224,13 @@
"score_name": "f1",
"num_of_instances": 2,
}
test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )
metric = metrics["metrics.bert_score.deberta_v3_base_mnli_xnli_ml"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
Expand All @@ -257,13 +257,13 @@
"score_name": "f1",
"num_of_instances": 2,
}
test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )
metric = metrics["metrics.sentence_bert.mpnet_base_v2"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
Expand All @@ -284,13 +284,13 @@
"score_name": "sbert_score",
"num_of_instances": 2,
}
test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
# test_metric(
# metric=metric,
# predictions=predictions,
# references=references,
# instance_targets=instance_targets,
# global_target=global_target,
# )
metric = metrics["metrics.reward.deberta_v3_large_v2"]
predictions = ["hello there General Dude", "foo bar foobar"]
references = [["How do you greet General Dude"], ["What is your name?"]]
Expand Down
3 changes: 2 additions & 1 deletion src/unitxt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

import pandas as pd
import requests
from datasets import DatasetDict, IterableDatasetDict
from datasets import DatasetDict, DownloadConfig, IterableDatasetDict
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm
Expand Down Expand Up @@ -256,6 +256,7 @@ def stream_dataset(self):
split=self.split,
trust_remote_code=settings.allow_unverified_code,
num_proc=self.num_proc,
download_config=DownloadConfig(max_retries=10),
)
except ValueError as e:
if "trust_remote_code" in str(e):
Expand Down
60 changes: 60 additions & 0 deletions tests/inference/test_inference_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unitxt.logging_utils import get_logger
from unitxt.metrics import (
BertScore,
)
from unitxt.settings_utils import get_settings
from unitxt.test_utils.metrics import test_metric

from tests.utils import UnitxtInferenceTestCase

logger = get_logger()
settings = get_settings()


class TestInferenceMetrics(UnitxtInferenceTestCase):
    """Inference-time checks for model-backed metrics (require downloading model weights)."""

    def test_bert_score_deberta_base_mnli(self):
        """Run BertScore with microsoft/deberta-base-mnli and compare against fixed targets.

        Two prediction/reference pairs are scored; `test_metric` verifies both the
        per-instance results and the aggregated global result (with confidence
        intervals) against the expected values below.
        """
        metric = BertScore(model_name="microsoft/deberta-base-mnli")
        preds = ["hello there general dude", "foo bar foobar"]
        refs = [
            ["hello there general kenobi", "hello there!"],
            ["foo bar foobar", "foo bar"],
        ]
        # Expected (rounded) scores for each individual prediction/reference pair.
        per_instance = [
            {
                "f1": 0.81,
                "precision": 0.85,
                "recall": 0.81,
                "score": 0.81,
                "score_name": "f1",
            },
            {
                "f1": 1.0,
                "precision": 1.0,
                "recall": 1.0,
                "score": 1.0,
                "score_name": "f1",
            },
        ]
        # Expected aggregate over both instances, including bootstrap CI bounds.
        aggregate = {
            "f1": 0.9,
            "f1_ci_high": 1.0,
            "f1_ci_low": 0.81,
            "precision": 0.93,
            "precision_ci_high": 1.0,
            "precision_ci_low": 0.85,
            "recall": 0.91,
            "recall_ci_high": 1.0,
            "recall_ci_low": 0.81,
            "score": 0.9,
            "score_ci_high": 1.0,
            "score_ci_low": 0.81,
            "score_name": "f1",
            "num_of_instances": 2,
        }
        test_metric(
            metric=metric,
            predictions=preds,
            references=refs,
            instance_targets=per_instance,
            global_target=aggregate,
        )
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 547,
"line_number": 548,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-02-09T10:34:53Z"
"generated_at": "2025-02-09T12:22:34Z"
}

0 comments on commit a0dceea

Please sign in to comment.