Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow aggregated tasks within benchmarks #1771

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1338736
fix: Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 11, 2025
f2920ff
feat: Update task filtering, fixing bug on MTEB
KennethEnevoldsen Jan 13, 2025
12aaa97
format
KennethEnevoldsen Jan 13, 2025
1be8ed8
remove "en-ext" from AmazonCounterfactualClassification
KennethEnevoldsen Jan 13, 2025
8aab5d0
fixed mteb(deu)
KennethEnevoldsen Jan 13, 2025
4dfe2ec
fix: simplify in a few areas
KennethEnevoldsen Jan 13, 2025
cd87ebb
wip
KennethEnevoldsen Jan 14, 2025
450953d
Merge branch 'correct-mteb-eng' into KennethEnevoldsen/issue-Allow-ag…
KennethEnevoldsen Jan 14, 2025
87816f1
tmp
KennethEnevoldsen Jan 15, 2025
f73ffb7
sav
KennethEnevoldsen Jan 16, 2025
33578ec
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
54d16f9
Merge remote-tracking branch 'origin' into KennethEnevoldsen/issue-Al…
KennethEnevoldsen Jan 17, 2025
b11f6b1
ensure correct formatting of eval_langs
KennethEnevoldsen Jan 17, 2025
0718389
ignore aggregate dataset
KennethEnevoldsen Jan 17, 2025
2bc375c
clean up dummy cases
KennethEnevoldsen Jan 17, 2025
5a9bd8c
add to mteb(eng, classic)
KennethEnevoldsen Jan 17, 2025
36cee38
format
KennethEnevoldsen Jan 17, 2025
8bb9026
clean up
KennethEnevoldsen Jan 17, 2025
60a8f0f
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
f65c68e
added fixed from comments
KennethEnevoldsen Jan 19, 2025
76d511c
Merge branch 'main' of https://github.com/embeddings-benchmark/mteb i…
KennethEnevoldsen Jan 19, 2025
14f3ae1
fix merge
KennethEnevoldsen Jan 19, 2025
66fb570
format
KennethEnevoldsen Jan 19, 2025
063e357
Updated task type
KennethEnevoldsen Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskBitextMining.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def evaluate(
subsets_to_run: list[HFSubset] | None = None,
*,
encode_kwargs: dict[str, Any] = {},
**kwargs,
**kwargs: Any,
) -> dict[HFSubset, ScoresDict]:
if not self.data_loaded:
self.load_data()
Expand Down
3 changes: 3 additions & 0 deletions mteb/abstasks/TaskMetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AnyUrl,
BaseModel,
BeforeValidator,
ConfigDict,
TypeAdapter,
field_validator,
)
Expand Down Expand Up @@ -227,6 +228,8 @@ class TaskMetadata(BaseModel):
bibtex_citation: The BibTeX citation for the dataset. Should be an empty string if no citation is available.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

dataset: dict

name: str
Expand Down
128 changes: 128 additions & 0 deletions mteb/abstasks/aggregated_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import logging
import random
from typing import Any

import numpy as np
import torch
from datasets import Dataset
from pydantic import field_validator

from mteb.abstasks.TaskMetadata import DescriptiveStatistics, HFSubset, TaskMetadata
from mteb.encoder_interface import Encoder
from mteb.load_results.task_results import TaskResult

from .AbsTask import AbsTask, ScoresDict

logger = logging.getLogger(__name__)


class AggregatedTaskMetadata(TaskMetadata):
    """A derivative of TaskMetadata used for aggregations of tasks. Can e.g. be used to create custom tasks
    which are a combination of existing tasks. For an example see CQADupstackRetrieval.

    The attributes are the same as TaskMetadata, with a few exceptions described below.

    Attributes:
        dataset: Always None as the task dataset is specified in its subtasks
        prompt: Always None as the task prompt is specified in its subtasks
        tasks: A list of tasks
    """

    # Both fields are pinned to None: the aggregated task has no dataset or
    # prompt of its own — they live on the individual subtasks.
    dataset: None = None
    prompt: None = None
    tasks: list[AbsTask]

    # NOTE(review): the three validators below share their names with
    # validators presumably declared on TaskMetadata; defining same-named
    # validators here replaces the parent ones, turning them into no-ops for
    # this subclass (dataset/prompt are always None here) — confirm the
    # parent validator names match exactly.
    @field_validator("dataset")
    def _check_dataset_path_is_specified(
        cls, dataset: dict[str, Any]
    ) -> dict[str, Any]:
        """No-op override: dataset path validation does not apply (dataset is None)."""
        return dataset  # skip validation

    @field_validator("dataset")
    def _check_dataset_revision_is_specified(
        cls, dataset: dict[str, Any]
    ) -> dict[str, Any]:
        """No-op override: dataset revision validation does not apply (dataset is None)."""
        return dataset  # skip validation

    @field_validator("prompt")
    def _check_prompt_is_valid(cls, prompt: None) -> None:
        """No-op override: prompt validation does not apply (prompt is None)."""
        return prompt  # skip validation


class AbsTaskAggregated(AbsTask):
    """A task computed as the aggregate of multiple subtasks.

    Evaluating this task runs every subtask through MTEB and reduces their
    main scores to a single unweighted mean, reported under the "default"
    subset (see `task_results_to_score`).

    Attributes:
        metadata: The aggregated metadata, including the list of subtasks.
        abstask_prompt: Always None; prompts are defined on the subtasks.
    """

    metadata: AggregatedTaskMetadata
    abstask_prompt: None = None

    def __init__(self, seed: int = 42, **kwargs: Any):
        """Cache the subtasks from the metadata and seed all RNGs.

        Args:
            seed: Seed applied to the Python, NumPy and torch RNGs for
                reproducibility.
            **kwargs: Only ``save_suffix`` is read; other keys are ignored.
        """
        # NOTE(review): super().__init__() is intentionally not called here;
        # confirm AbsTask.__init__ performs no required setup beyond the
        # seeding/save_suffix handling replicated below.
        self.tasks = self.metadata.tasks
        self.save_suffix = kwargs.get("save_suffix", "")

        self.seed = seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

    def evaluate(
        self,
        model: Encoder,
        split: str = "test",
        subsets_to_run: list[HFSubset] | None = None,
        *,
        encode_kwargs: dict[str, Any] | None = None,
        mteb_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[HFSubset, ScoresDict]:
        """Evaluate the model on every subtask and aggregate the scores.

        Args:
            model: The encoder to evaluate.
            split: The dataset split evaluated on each subtask.
            subsets_to_run: Ignored for aggregated tasks; a warning is logged
                if it is supplied.
            encode_kwargs: Forwarded to the subtask evaluations. ``None``
                (the default) means an empty dict; a fresh dict is created
                per call to avoid the shared mutable-default pitfall.
            mteb_kwargs: Extra keyword arguments forwarded to ``MTEB.run``.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            A mapping with the single key "default" holding the aggregated scores.
        """
        from mteb.evaluation.MTEB import MTEB  # to prevent circular imports

        # Materialize fresh dicts per call instead of sharing mutable defaults.
        if encode_kwargs is None:
            encode_kwargs = {}
        if mteb_kwargs is None:
            mteb_kwargs = {}

        if subsets_to_run:
            logger.warning(
                "Specifying which subset to run is not supported for aggregated tasks. It will be ignored."
            )

        bench = MTEB(tasks=self.tasks)
        task_results = bench.run(
            model=model,
            encode_kwargs=encode_kwargs,
            eval_subsets=None,
            eval_splits=[split],
            verbosity=0,
            **mteb_kwargs,
        )
        return {"default": self.task_results_to_score(task_results)}

    def task_results_to_score(self, task_results: list[TaskResult]) -> ScoresDict:
        """Reduce the subtask results to the unweighted mean of their main scores."""
        main_scores = [
            task_res.get_score(
                getter=lambda scores: scores[self.metadata.main_score]
            )
            for task_res in task_results
        ]
        return {self.metadata.main_score: np.mean(main_scores)}

    def load_data(self, **kwargs: Any) -> None:
        """Load the data of every subtask and mark this task as loaded.

        Note: ``kwargs`` are intentionally not forwarded; each subtask loads
        with its own defaults.
        """
        for task in self.tasks:
            task.load_data()

        self.data_loaded = True

    def _evaluate_subset(
        self,
        model: Encoder,
        data_split: Dataset,
        *,
        parallel: bool = False,
        encode_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ScoresDict:
        """Not supported: aggregated tasks delegate evaluation to their subtasks."""
        raise NotImplementedError()

    def _calculate_metrics_from_split(
        self, split: str, hf_subset: str | None = None, compute_overall: bool = False
    ) -> DescriptiveStatistics:
        """Not supported: descriptive statistics are computed per subtask."""
        # It is awkward to remove functionality inherited from the base class;
        # raising here makes the unsupported operation explicit rather than
        # silently returning wrong statistics.
        raise NotImplementedError()
2 changes: 1 addition & 1 deletion mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
)

if tasks is not None:
self._tasks = tasks
self._tasks: Iterable[str | AbsTask] = tasks
if isinstance(tasks[0], Benchmark):
self.benchmarks = tasks
self._tasks = list(chain.from_iterable(tasks))
x-tabdeveloping marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .aggregated_tasks import *
from .BitextMining import *
from .Classification import *
from .Clustering import *
Expand Down
72 changes: 72 additions & 0 deletions mteb/tasks/aggregated_tasks/CQADupStackRetrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from mteb.abstasks import AbsTask
from mteb.abstasks.aggregated_task import AbsTaskAggregated, AggregatedTaskMetadata
from mteb.tasks.Retrieval import (
CQADupstackAndroidRetrieval,
CQADupstackEnglishRetrieval,
CQADupstackGamingRetrieval,
CQADupstackGisRetrieval,
CQADupstackMathematicaRetrieval,
CQADupstackPhysicsRetrieval,
CQADupstackProgrammersRetrieval,
CQADupstackStatsRetrieval,
CQADupstackTexRetrieval,
CQADupstackUnixRetrieval,
CQADupstackWebmastersRetrieval,
CQADupstackWordpressRetrieval,
)

# The 12 CQADupstack subforum retrieval tasks that make up the aggregated
# CQADupstackRetrieval task defined below.
task_list_cqa: list[AbsTask] = [
    CQADupstackAndroidRetrieval(),
    CQADupstackEnglishRetrieval(),
    CQADupstackGamingRetrieval(),
    CQADupstackGisRetrieval(),
    CQADupstackMathematicaRetrieval(),
    CQADupstackPhysicsRetrieval(),
    CQADupstackProgrammersRetrieval(),
    CQADupstackStatsRetrieval(),
    CQADupstackTexRetrieval(),
    CQADupstackUnixRetrieval(),
    CQADupstackWebmastersRetrieval(),
    CQADupstackWordpressRetrieval(),
]


class CQADupstackRetrieval(AbsTaskAggregated):
    """Aggregated CQADupstack retrieval task.

    The reported main score (ndcg_at_10) is the aggregate of the main scores
    of the 12 CQADupstack subforum tasks listed in `task_list_cqa`.
    """

    metadata = AggregatedTaskMetadata(
        name="CQADupstackRetrieval",
        description="CQADupStack: A Benchmark Data Set for Community Question-Answering Research",
        reference="http://nlp.cis.unimelb.edu.au/resources/cqadupstack/",
        tasks=task_list_cqa,
        type="Retrieval",
        category="s2p",
        modalities=["text"],
        eval_splits=["test"],
        eval_langs=["eng-Latn"],
        main_score="ndcg_at_10",
        date=("2010-01-01", "2014-01-01"),  # 2010 is start of stackexchange
        domains=["Written", "Non-fiction"],
        task_subtypes=["Question answering"],
        license="apache-2.0",
        annotations_creators="derived",
        dialect=[],
        sample_creation="found",
        bibtex_citation="""@inproceedings{hoogeveen2015,
author = {Hoogeveen, Doris and Verspoor, Karin M. and Baldwin, Timothy},
title = {CQADupStack: A Benchmark Data Set for Community Question-Answering Research},
booktitle = {Proceedings of the 20th Australasian Document Computing Symposium (ADCS)},
series = {ADCS '15},
year = {2015},
isbn = {978-1-4503-4040-3},
location = {Parramatta, NSW, Australia},
pages = {3:1--3:8},
articleno = {3},
numpages = {8},
url = {http://doi.acm.org/10.1145/2838931.2838934},
doi = {10.1145/2838931.2838934},
acmid = {2838934},
publisher = {ACM},
address = {New York, NY, USA},
}""",
    )
5 changes: 5 additions & 0 deletions mteb/tasks/aggregated_tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from .CQADupStackRetrieval import CQADupstackRetrieval

__all__ = ["CQADupstackRetrieval"]
Loading