-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: streamline metric testing and enhance dataset loading configuration
Signed-off-by: elronbandel <[email protected]>
- Loading branch information
1 parent
12e328b
commit a0dceea
Showing
5 changed files
with
98 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from unitxt.logging_utils import get_logger | ||
from unitxt.metrics import ( | ||
BertScore, | ||
) | ||
from unitxt.settings_utils import get_settings | ||
from unitxt.test_utils.metrics import test_metric | ||
|
||
from tests.utils import UnitxtInferenceTestCase | ||
|
||
# Module-level logger and settings objects obtained from the unitxt helper functions imported above.
# NOTE(review): neither name is referenced in the visible part of this file - confirm they are needed.
logger = get_logger() | ||
settings = get_settings() | ||
|
||
|
||
class TestInferenceMetrics(UnitxtInferenceTestCase):
    """Metric tests built on ``UnitxtInferenceTestCase``."""

    def test_bert_score_deberta_base_mnli(self):
        """Check ``BertScore`` with ``microsoft/deberta-base-mnli`` against fixed targets."""
        scorer = BertScore(model_name="microsoft/deberta-base-mnli")

        # Two predictions, each paired with its own list of references.
        candidate_texts = ["hello there general dude", "foo bar foobar"]
        gold_texts = [
            ["hello there general kenobi", "hello there!"],
            ["foo bar foobar", "foo bar"],
        ]

        # Expected per-instance results. The second prediction appears verbatim
        # among its references, so every score for it is expected to be 1.0.
        partial_match_scores = {
            "f1": 0.81,
            "precision": 0.85,
            "recall": 0.81,
            "score": 0.81,
            "score_name": "f1",
        }
        exact_match_scores = {
            "f1": 1.0,
            "precision": 1.0,
            "recall": 1.0,
            "score": 1.0,
            "score_name": "f1",
        }

        # Expected aggregate results, including the ``*_ci_low`` / ``*_ci_high``
        # bounds reported alongside each score.
        expected_global = {
            "f1": 0.9,
            "f1_ci_high": 1.0,
            "f1_ci_low": 0.81,
            "precision": 0.93,
            "precision_ci_high": 1.0,
            "precision_ci_low": 0.85,
            "recall": 0.91,
            "recall_ci_high": 1.0,
            "recall_ci_low": 0.81,
            "score": 0.9,
            "score_ci_high": 1.0,
            "score_ci_low": 0.81,
            "score_name": "f1",
            "num_of_instances": 2,
        }

        test_metric(
            metric=scorer,
            predictions=candidate_texts,
            references=gold_texts,
            instance_targets=[partial_match_scores, exact_match_scores],
            global_target=expected_global,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters