From ed89867e1ad8c0312e6295988b9d80a21026bee0 Mon Sep 17 00:00:00 2001 From: Ryokan Ri Date: Fri, 10 May 2024 13:04:36 +0900 Subject: [PATCH 1/5] add --eval_include option --- src/jmteb/__main__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py index 360f5e8..3dfdb5d 100644 --- a/src/jmteb/__main__.py +++ b/src/jmteb/__main__.py @@ -57,17 +57,38 @@ def main( parser.add_argument("--config", action=ActionConfigFile, help="Path to the config file.") parser.add_argument("--save_dir", type=str, default=None, help="Directory to save the outputs") parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists") + parser.add_argument("--eval_include", type=list[str], default=None, help="Evaluators to include.") parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.") args = parser.parse_args() + if args.eval_include is not None: + # check if the specified evaluators are valid + evaluator_keys = list(args.evaluators.keys()) + for include_key in args.eval_include: + if include_key not in evaluator_keys: + raise ValueError(f"Invalid evaluator name: {include_key}") + + # remove evaluators not in eval_include + for key in evaluator_keys: + if key not in args.eval_include: + args.evaluators.pop(key) + if args.eval_exclude is not None: + # check if the specified evaluators are valid evaluator_keys = list(args.evaluators.keys()) + for exclude_key in args.eval_exclude: + if exclude_key not in evaluator_keys: + raise ValueError(f"Invalid evaluator name: {exclude_key}") + # remove evaluators in eval_exclude for key in evaluator_keys: if key in args.eval_exclude: args.evaluators.pop(key) + if len(args.evaluators) == 0: + raise ValueError("No evaluator is selected. Please check the config file or the command line arguments.") + args = parser.instantiate_classes(args) if isinstance(args.evaluators, str): raise ValueError( From f520a7388d70d4ba4fa968a04a32e701d1ba2332 Mon Sep 17 00:00:00 2001 From: Ryokan Ri Date: Fri, 10 May 2024 13:39:53 +0900 Subject: [PATCH 2/5] use DummyTextEmbedder for test embedder --- tests/conftest.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6faf7e4..9a104d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,9 @@ +from __future__ import annotations + +import numpy as np import pytest -from jmteb.embedders.sbert_embedder import SentenceBertEmbedder +from jmteb.embedders import TextEmbedder def pytest_addoption(parser: pytest.Parser): @@ -21,6 +24,19 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser): item.add_marker(skip_slow) +class DummyTextEmbedder(TextEmbedder): + def encode(self, text: str | list[str]) -> np.ndarray: + if isinstance(text, str): + batch_size = 1 + else: + batch_size = len(text) + + return np.random.random((batch_size, self.get_output_dim())) + + def get_output_dim(self) -> int: + return 32 + + @pytest.fixture(scope="module") -def embedder(model_name_or_path: str = "prajjwal1/bert-tiny"): - return SentenceBertEmbedder(model_name_or_path=model_name_or_path) +def embedder(): + return DummyTextEmbedder() From 6fe095020a9ea8e20da5de856fdff297680f27eb Mon Sep 17 00:00:00 2001 From: Ryokan Ri Date: Fri, 10 May 2024 13:40:07 +0900 Subject: [PATCH 3/5] debug main --- src/jmteb/__main__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py index 3dfdb5d..bb9af7f 100644 --- a/src/jmteb/__main__.py +++ b/src/jmteb/__main__.py @@ -40,7 +40,8 @@ def main( logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}") - logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}") + if save_dir: + logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}") score_recorder.record_summary() From e22aabb40a96f865811881512a7dfab271f5f228 Mon Sep 17 00:00:00 2001 From: Ryokan Ri Date: Fri, 10 May 2024 13:40:23 +0900 Subject: [PATCH 4/5] add test for main --- tests/test_main.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/test_main.py diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..4ac552e --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,33 @@ +import subprocess +import tempfile +from pathlib import Path + +from evaluator.test_sts_evaluator import DummySTSDataset + +from jmteb.__main__ import main +from jmteb.evaluators import STSEvaluator + + +def test_main(embedder): + main( + text_embedder=embedder, + evaluators={"sts": STSEvaluator(val_dataset=DummySTSDataset(), test_dataset=DummySTSDataset())}, + save_dir=None, + overwrite_cache=False, + ) + + +def test_main_cli(): + with tempfile.TemporaryDirectory() as f: + # fmt: off + command = [ + "python", "-m", "jmteb", + "--embedder", "tests.conftest.DummyTextEmbedder", + "--save_dir", f, + "--eval_include", '["jsts"]', + ] + # fmt: on + result = subprocess.run(command) + assert result.returncode == 0 + + assert (Path(f) / "summary.json").exists() From bbbd23899d5d18df09ad83bd18064f5175830273 Mon Sep 17 00:00:00 2001 From: lsz05 Date: Thu, 16 May 2024 14:29:05 +0900 Subject: [PATCH 5/5] Version bump-up to 1.1.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5f78e66..da5bbaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding name = "JMTEB" packages = [{from = "src", include = "jmteb"}] readme = "README.md" -version = "1.1.0" +version = "1.1.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"