fix: Allow aggregated tasks within benchmarks #1771

Open · wants to merge 24 commits into main
Changes from 18 commits
Commits (24)
1338736
fix: Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 11, 2025
f2920ff
feat: Update task filtering, fixing bug on MTEB
KennethEnevoldsen Jan 13, 2025
12aaa97
format
KennethEnevoldsen Jan 13, 2025
1be8ed8
remove "en-ext" from AmazonCounterfactualClassification
KennethEnevoldsen Jan 13, 2025
8aab5d0
fixed mteb(deu)
KennethEnevoldsen Jan 13, 2025
4dfe2ec
fix: simplify in a few areas
KennethEnevoldsen Jan 13, 2025
cd87ebb
wip
KennethEnevoldsen Jan 14, 2025
450953d
Merge branch 'correct-mteb-eng' into KennethEnevoldsen/issue-Allow-ag…
KennethEnevoldsen Jan 14, 2025
87816f1
tmp
KennethEnevoldsen Jan 15, 2025
f73ffb7
sav
KennethEnevoldsen Jan 16, 2025
33578ec
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
54d16f9
Merge remote-tracking branch 'origin' into KennethEnevoldsen/issue-Al…
KennethEnevoldsen Jan 17, 2025
b11f6b1
ensure correct formatting of eval_langs
KennethEnevoldsen Jan 17, 2025
0718389
ignore aggregate dataset
KennethEnevoldsen Jan 17, 2025
2bc375c
clean up dummy cases
KennethEnevoldsen Jan 17, 2025
5a9bd8c
add to mteb(eng, classic)
KennethEnevoldsen Jan 17, 2025
36cee38
format
KennethEnevoldsen Jan 17, 2025
8bb9026
clean up
KennethEnevoldsen Jan 17, 2025
60a8f0f
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
f65c68e
added fixed from comments
KennethEnevoldsen Jan 19, 2025
76d511c
Merge branch 'main' of https://github.com/embeddings-benchmark/mteb i…
KennethEnevoldsen Jan 19, 2025
14f3ae1
fix merge
KennethEnevoldsen Jan 19, 2025
66fb570
format
KennethEnevoldsen Jan 19, 2025
063e357
Updated task type
KennethEnevoldsen Jan 21, 2025
3 changes: 2 additions & 1 deletion mteb/abstasks/AbsTask.py
@@ -63,7 +63,7 @@ class AbsTask(ABC):
dataset: dict[HFSubset, DatasetDict] | None = None # type: ignore
data_loaded: bool = False
is_multilingual: bool = False
hf_subsets: list[HFSubset] | None = None
hf_subsets: list[HFSubset]

def __init__(self, seed: int = 42, **kwargs: Any):
self.save_suffix = kwargs.get("save_suffix", "")
@@ -73,6 +73,7 @@ def __init__(self, seed: int = 42, **kwargs: Any):
np.random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed_all(self.seed)
self.hf_subsets = list(self.metadata.hf_subsets_to_langscripts.keys())

def check_if_dataset_is_superseded(self):
"""Check if the dataset is superseded by a newer version"""
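A minimal sketch of what the AbsTask change above implies, assuming an unfiltered task fetched via mteb.get_task (the task name is only an illustration): hf_subsets is now populated in __init__ from the metadata instead of defaulting to None.

```python
import mteb

# Illustrative task; any registered task behaves the same way.
task = mteb.get_task("Banking77Classification")

# After this change, hf_subsets is filled from the metadata at construction
# time, so callers no longer need a None check before iterating over it.
assert task.hf_subsets == list(task.metadata.hf_subsets_to_langscripts.keys())
print(task.hf_subsets)  # e.g. ["default"] for a monolingual task
```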
2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskBitextMining.py
@@ -71,7 +71,7 @@ def evaluate(
subsets_to_run: list[HFSubset] | None = None,
*,
encode_kwargs: dict[str, Any] = {},
**kwargs,
**kwargs: Any,
) -> dict[HFSubset, ScoresDict]:
if not self.data_loaded:
self.load_data()
23 changes: 19 additions & 4 deletions mteb/abstasks/TaskMetadata.py
@@ -81,6 +81,7 @@
"machine-translated and verified",
"machine-translated and localized",
"LM-generated and verified",
"multiple",
]

TASK_TYPE = Literal[
@@ -168,9 +169,10 @@
"gpl-3.0",
"cdla-sharing-1.0",
"mpl-2.0",
"multiple",
]
)

MODALITIES = Literal["text"]
METRIC_NAME = str
METRIC_VALUE = Union[int, float, dict[str, Any]]

@@ -227,13 +229,13 @@ class TaskMetadata(BaseModel):
bibtex_citation: The BibTeX citation for the dataset. Should be an empty string if no citation is available.
"""

dataset: dict
dataset: dict[str, Any]

name: str
description: str
prompt: str | PromptDict | None = None
type: TASK_TYPE
modalities: list[Literal["text"]] = ["text"]
modalities: list[MODALITIES] = ["text"]
category: TASK_CATEGORY | None = None
reference: STR_URL | None = None

@@ -334,6 +336,15 @@ def _check_language_code(code):
f"Invalid script code: {script}, you can find valid ISO 15924 codes in {path_to_lang_scripts}"
)

@property
def bcp47_codes(self) -> list[ISO_LANGUAGE_SCRIPT]:
Collaborator
Why did you introduce a new method for filtering languages?

Contributor Author
It is not a new method; it is a method for fetching languages in BCP-47 format (eng-Latn as opposed to eng). It is used to compute eval_langs for the aggregated task (using just the language code breaks the tests).

Collaborator
Maybe we need to standardize how we specify languages (#1822), as the current approach is a bit problematic (#1821 (comment)).

"""Return the languages and script codes of the dataset formatting in accordance with the BCP-47 standard."""
if isinstance(self.eval_langs, dict):
return sorted(
{lang for langs in self.eval_langs.values() for lang in langs}
)
return sorted(set(self.eval_langs))
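A standalone sketch of the distinction discussed in the thread above; the eval_langs value and exact ordering are made up for illustration. bcp47_codes keeps the full language-script pairs, while the existing languages property reduces them to bare ISO 639-3 codes.

```python
# Hypothetical dict-valued eval_langs (HF subset -> language-script codes).
eval_langs = {"en": ["eng-Latn"], "de": ["deu-Latn"], "en-de": ["eng-Latn", "deu-Latn"]}

# What bcp47_codes computes: the flattened, deduplicated BCP-47 style codes.
bcp47 = sorted({lang for langs in eval_langs.values() for lang in langs})
print(bcp47)  # ['deu-Latn', 'eng-Latn']

# What the `languages` property yields instead: script suffixes stripped.
iso639_3 = sorted({code.split("-")[0] for code in bcp47})
print(iso639_3)  # ['deu', 'eng']
```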

@property
def languages(self) -> list[str]:
"""Return the languages of the dataset as iso639-3 codes."""
@@ -420,8 +431,12 @@ def n_samples(self) -> dict[str, int] | None:
for subset, subset_value in stats.items():
if subset == "hf_subset_descriptive_stats":
continue
n_samples[subset] = subset_value["num_samples"]
n_samples[subset] = subset_value["num_samples"] # type: ignore
return n_samples

def __hash__(self) -> int:
return hash(self.model_dump_json())

@property
def revision(self) -> str:
return self.dataset["revision"]
171 changes: 171 additions & 0 deletions mteb/abstasks/aggregate_task_metadata.py
@@ -0,0 +1,171 @@
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any, Literal

from pydantic import ConfigDict, model_validator

from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import (
ANNOTATOR_TYPE,
LANGUAGES,
LICENSES,
MODALITIES,
SAMPLE_CREATION_METHOD,
STR_DATE,
TASK_DOMAIN,
TASK_SUBTYPE,
HFSubset,
TaskMetadata,
)
from mteb.languages import ISO_LANGUAGE_SCRIPT

logger = logging.getLogger(__name__)


class AggregateTaskMetadata(TaskMetadata):
"""Metadata for an aggregation of tasks. This description only covers exceptions to the TaskMetadata. Many of the field if not filled out will be
autofilled from its tasks.

Attributes:
name: The name of the aggregated task.
description: A description of the task. Should explain the aggregation.
prompt: An aggregate task does not have a prompt, thus this value is always None.
dataset: The dataset for the aggregated task is specified in its tasks. The aggregate task thus only specifies the revision and uses a
placeholder path.
tasks: A list of tasks, the majority of the metadata is described within its tasks.
eval_splits: The splits of the tasks used for evaluation.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

name: str
description: str
dataset: dict[str, Any] = {
"path": "aggregate tasks do not have a path", # just a place holder
"revision": "1",
}

tasks: list[AbsTask]
main_score: str
type: Literal["aggregate-task"] = "aggregate-task"
Collaborator
This breaks task types on the leaderboard, and also the TASK_TYPES type definition. Can't we either make it a property or force people to specify a task type when they introduce an aggregate task?

eval_splits: list[str]
eval_langs: LANGUAGES = []
prompt: None = None
reference: str | None = None
bibtex_citation: str | None = None

@property
def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISO_LANGUAGE_SCRIPT]]:
"""Return a dictionary mapping huggingface subsets to languages."""
return {"default": self.eval_langs} # type: ignore

@model_validator(mode="after") # type: ignore
def compute_unfilled_cases(self) -> AggregateTaskMetadata:
if not self.eval_langs:
self.eval_langs = self.compute_eval_langs()
if not self.date:
self.date = self.compute_date()
if not self.domains:
self.domains = self.compute_domains()
if not self.task_subtypes:
self.task_subtypes = self.compute_task_subtypes()
if not self.license:
self.license = self.compute_license()
if not self.annotations_creators:
self.annotations_creators = self.compute_annotations_creators()
if not self.dialect:
self.dialect = self.compute_dialect()
if not self.sample_creation:
self.sample_creation = self.compute_sample_creation()
if not self.modalities:
self.modalities = self.compute_modalities()

return self

def compute_eval_langs(self) -> list[ISO_LANGUAGE_SCRIPT]:
langs = set()
for task in self.tasks:
langs.update(set(task.metadata.bcp47_codes))
return list(langs)

def compute_date(self) -> tuple[STR_DATE, STR_DATE] | None:
# get min max date from tasks
dates = []
for task in self.tasks:
if task.metadata.date:
dates.append(datetime.fromisoformat(task.metadata.date[0]))
dates.append(datetime.fromisoformat(task.metadata.date[1]))

if not dates:
return None

min_date = min(dates)
max_date = max(dates)
return min_date.isoformat(), max_date.isoformat()

def compute_domains(self) -> list[TASK_DOMAIN] | None:
domains = set()
for task in self.tasks:
if task.metadata.domains:
domains.update(set(task.metadata.domains))
if domains:
return list(domains)
return None

def compute_task_subtypes(self) -> list[TASK_SUBTYPE] | None:
subtypes = set()
for task in self.tasks:
if task.metadata.task_subtypes:
subtypes.update(set(task.metadata.task_subtypes))
if subtypes:
return list(subtypes)
return None

def compute_license(self) -> LICENSES | None:
licenses = set()
for task in self.tasks:
if task.metadata.license:
licenses.add(task.metadata.license)
if len(licenses) > 1:
return "multiple"
return None

def compute_annotations_creators(self) -> ANNOTATOR_TYPE | None:
creators = set()
for task in self.tasks:
if task.metadata.annotations_creators:
creators.add(task.metadata.annotations_creators)
if len(creators) > 1:
logger.warning(
f"Multiple annotations_creators found for tasks in {self.name}. Using None as annotations_creators."
)
return None

def compute_dialect(self) -> list[str] | None:
dialects = set()
for task in self.tasks:
if task.metadata.dialect:
dialects.update(set(task.metadata.dialect))
if dialects:
return list(dialects)
return None

def compute_sample_creation(self) -> SAMPLE_CREATION_METHOD | None:
sample_creations = set()
for task in self.tasks:
if task.metadata.sample_creation:
sample_creations.add(task.metadata.sample_creation)
if len(sample_creations) > 1:
return "multiple"
return None

def compute_modalities(self) -> list[MODALITIES] | None:
modalities = set()
for task in self.tasks:
if task.metadata.modalities:
modalities.update(set(task.metadata.modalities))
if modalities:
return list(modalities)
return None
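A rough usage sketch of the autofill behaviour implemented by compute_unfilled_cases above, assuming the module path added in this diff. The constituent tasks, field values, and the assumption that the remaining TaskMetadata fields default to None are all illustrative, not part of this PR.

```python
import mteb
from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata

# Illustrative constituent tasks; any registered tasks could be aggregated.
tasks = list(mteb.get_tasks(tasks=["STS12", "STS13", "STS14"]))

meta = AggregateTaskMetadata(
    name="STSAggregateExample",  # hypothetical name
    description="Example aggregate over a few STS tasks.",
    tasks=tasks,
    main_score="cosine_spearman",
    eval_splits=["test"],
)

# The model validator fills the unfilled fields from the constituent tasks:
# eval_langs from each task's bcp47_codes, date as the (min, max) range,
# domains/subtypes/dialects as unions, and license set to "multiple" if the
# tasks disagree.
print(meta.eval_langs)
print(meta.date)
```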