From a23a41e5ea7eb4d80c24db228005464456b2b7f4 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Mon, 6 Jan 2025 17:55:53 +0200 Subject: [PATCH] Add new processors and update type hints for response assessment tasks Signed-off-by: elronbandel --- prepare/processors/processors.py | 17 +++++++++++ .../pairwise_comparison/multi_turn.py | 9 ++++-- .../multi_turn_with_reference.py | 9 ++++-- .../pairwise_comparison/single_turn.py | 9 ++++-- .../single_turn_with_reference.py | 9 ++++-- src/unitxt/processors.py | 29 ++++++++++++++++++- tests/library/test_postprocessors.py | 2 +- utils/.secrets.baseline | 6 ++-- 8 files changed, 75 insertions(+), 15 deletions(-) diff --git a/prepare/processors/processors.py b/prepare/processors/processors.py index feb66be69a..9884b86a22 100644 --- a/prepare/processors/processors.py +++ b/prepare/processors/processors.py @@ -29,12 +29,15 @@ TakeFirstNonEmptyLine, TakeFirstWord, TakeLastNonEmptyLine, + TakeUntilPunc, + Title, ToYesOrNone, Upper, YesNoToInt, YesToOneElseZero, ) from unitxt.settings_utils import get_constants +from unitxt.string_operators import Strip constants = get_constants() logger = get_logger() @@ -74,6 +77,12 @@ def add_processor_and_operator_to_catalog( overwrite=True, ) +add_processor_and_operator_to_catalog( + artifact_name="strip", + operator=Strip(), + overwrite=True, +) + add_processor_and_operator_to_catalog( artifact_name="take_last_non_empty_line", operator=TakeLastNonEmptyLine(), @@ -91,6 +100,14 @@ def add_processor_and_operator_to_catalog( artifact_name="lower_case_till_punc", operator=LowerCaseTillPunc(), overwrite=True ) +add_processor_and_operator_to_catalog( + artifact_name="take_until_punc", operator=TakeUntilPunc(), overwrite=True +) + +add_processor_and_operator_to_catalog( + artifact_name="title", operator=Title(), overwrite=True +) + add_processor_and_operator_to_catalog( artifact_name="hate_speech_or_not_hate_speech", operator=StringEquals(string="hate speech"), diff --git a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py index c4c6f259cb..2c825c10db 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Literal, Tuple from unitxt.blocks import Task from unitxt.catalog import add_to_catalog @@ -10,9 +10,12 @@ "dialog_b": List[Tuple[str, str]], }, reference_fields={ - "winner": str - }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, + "winner": Literal["choice_a", "choice_b", "tie"], + "classes": List[Literal["choice_a", "choice_b", "tie"]], + }, + defaults={"classes": ["choice_a", "choice_b", "tie"]}, metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"], + prediction_type=str, ), "tasks.response_assessment.pairwise_comparison.multi_turn", overwrite=True, diff --git a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py index 967e2104f3..a7171813a5 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Literal, Tuple from unitxt.blocks import Task from unitxt.catalog import add_to_catalog @@ -11,9 +11,12 @@ "reference_dialog": List[Tuple[str, str]], }, reference_fields={ - "winner": str - }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, + "winner": Literal["choice_a", "choice_b", "tie"], + "classes": List[Literal["choice_a", "choice_b", "tie"]], + }, + defaults={"classes": ["choice_a", "choice_b", "tie"]}, metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"], + prediction_type=str, ), "tasks.response_assessment.pairwise_comparison.multi_turn_with_reference", overwrite=True, diff --git a/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py b/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py index 629f08fa6e..cb33a8db4c 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py @@ -1,3 +1,5 @@ +from typing import List, Literal + from unitxt.blocks import Task from unitxt.catalog import add_to_catalog @@ -9,9 +11,12 @@ "answer_b": str, }, reference_fields={ - "winner": str - }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']" + "winner": Literal["choice_a", "choice_b", "tie"], + "classes": List[Literal["choice_a", "choice_b", "tie"]], + }, + defaults={"classes": ["choice_a", "choice_b", "tie"]}, metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"], + prediction_type=str, ), "tasks.response_assessment.pairwise_comparison.single_turn", overwrite=True, diff --git a/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py b/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py index 2ec825ffed..b7eb081391 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py @@ -1,3 +1,5 @@ +from typing import List, Literal + from unitxt.blocks import Task from unitxt.catalog import add_to_catalog @@ -10,9 +12,12 @@ "reference_answer": str, }, reference_fields={ - "winner": str - }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, + "winner": Literal["choice_a", "choice_b", "tie"], + "classes": List[Literal["choice_a", "choice_b", "tie"]], + }, + defaults={"classes": ["choice_a", "choice_b", "tie"]}, metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"], + prediction_type=str, ), "tasks.response_assessment.pairwise_comparison.single_turn_with_reference", overwrite=True, diff --git a/src/unitxt/processors.py b/src/unitxt/processors.py index 5880af56b4..03e8675026 100644 --- a/src/unitxt/processors.py +++ b/src/unitxt/processors.py @@ -170,6 +170,27 @@ def process_value(self, text: Any) -> Any: return str(text).upper() +class Title(FieldOperator): + def process_value(self, text: Any) -> Any: + return str(text).title() + + +class TakeUntilPunc(FieldOperator): + _requirements_list = ["regex"] + + def prepare(self): + super().prepare() + import regex + + self.pattern = regex.compile(r"\p{P}+") + + def process_value(self, text: Any) -> Any: + match = self.pattern.search(text) + if match: + text = text[: match.start()] + return text + + @deprecation("2.0.0", alternative=Lower) class LowerCase(Lower): pass @@ -294,10 +315,16 @@ def process_value(self, text: Any) -> Any: class ExtractMtBenchLabelJudgment(FieldOperator): + options = { + "A": "choice_a", + "B": "choice_b", + "C": "tie", + } + def process_value(self, text: Any) -> Any: match = re.search(r"\[\[([^\]]+)\]\]", text) try: - return str(match.group(1)) + return self.options.get(str(match.group(1)), "None") except: return "None" diff --git a/tests/library/test_postprocessors.py b/tests/library/test_postprocessors.py index 9f1c44c10a..152caa6408 100644 --- a/tests/library/test_postprocessors.py +++ b/tests/library/test_postprocessors.py @@ -395,7 +395,7 @@ def test_extract_mt_bench_label_judgment(self): "good", "bad [[C]]", ] - targets = ["A", "B", "A", "None", "C"] + targets = ["choice_a", "choice_b", "choice_a", "None", "tie"] check_operator( operator=postprocessor, diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index a03b0cf0b5..cf33869c7e 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "src/unitxt/inference.py", "hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729", "is_verified": false, - "line_number": 1250, + "line_number": 1249, "is_secret": false }, { @@ -141,7 +141,7 @@ "filename": "src/unitxt/inference.py", "hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79", "is_verified": false, - "line_number": 1696, + "line_number": 1663, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2025-01-05T08:58:59Z" + "generated_at": "2025-01-01T10:22:19Z" }