Commit 365ead5
Merge pull request #25 from sbintuitions/dev
[dev to main] v1.1.1
ryokan0123 authored May 16, 2024
2 parents 818d4d8 + 6933a5f commit 365ead5
Showing 4 changed files with 76 additions and 5 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding
 name = "JMTEB"
 packages = [{from = "src", include = "jmteb"}]
 readme = "README.md"
-version = "1.1.0"
+version = "1.1.1"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
src/jmteb/__main__.py (24 changes: 23 additions & 1 deletion)
@@ -40,7 +40,8 @@ def main(

         logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")

-    logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}")
+    if save_dir:
+        logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}")
     score_recorder.record_summary()


@@ -57,17 +58,38 @@ def main(
     parser.add_argument("--config", action=ActionConfigFile, help="Path to the config file.")
     parser.add_argument("--save_dir", type=str, default=None, help="Directory to save the outputs")
     parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists")
+    parser.add_argument("--eval_include", type=list[str], default=None, help="Evaluators to include.")
+    parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.")

     args = parser.parse_args()

+    if args.eval_include is not None:
+        # check if the specified evaluators are valid
+        evaluator_keys = list(args.evaluators.keys())
+        for include_key in args.eval_include:
+            if include_key not in evaluator_keys:
+                raise ValueError(f"Invalid evaluator name: {include_key}")
+
+        # remove evaluators not in eval_include
+        for key in evaluator_keys:
+            if key not in args.eval_include:
+                args.evaluators.pop(key)
+
+    if args.eval_exclude is not None:
+        # check if the specified evaluators are valid
+        evaluator_keys = list(args.evaluators.keys())
+        for exclude_key in args.eval_exclude:
+            if exclude_key not in evaluator_keys:
+                raise ValueError(f"Invalid evaluator name: {exclude_key}")
+
+        # remove evaluators in eval_exclude
+        for key in evaluator_keys:
+            if key in args.eval_exclude:
+                args.evaluators.pop(key)
+
+    if len(args.evaluators) == 0:
+        raise ValueError("No evaluator is selected. Please check the config file or the command line arguments.")
+
     args = parser.instantiate_classes(args)
     if isinstance(args.evaluators, str):
         raise ValueError(
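The new flags above validate every requested name against the configured evaluators, then prune the evaluators mapping in place before any class is instantiated. A minimal standalone sketch of the same include/exclude filtering on a plain dict (the function name filter_evaluators and the sample keys are illustrative, not from the repo):

def filter_evaluators(evaluators: dict, include=None, exclude=None) -> dict:
    # reject unknown names up front, mirroring the ValueError in __main__.py
    known = set(evaluators)
    for name in (include or []) + (exclude or []):
        if name not in known:
            raise ValueError(f"Invalid evaluator name: {name}")
    # keep a key only if it passes both the include and the exclude filter
    kept = {
        key: value
        for key, value in evaluators.items()
        if (include is None or key in include) and (exclude is None or key not in exclude)
    }
    if not kept:
        raise ValueError("No evaluator is selected.")
    return kept

# e.g. filter_evaluators({"jsts": 1, "jsick": 2}, include=["jsts"]) -> {"jsts": 1}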
tests/conftest.py (22 changes: 19 additions & 3 deletions)
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
+import numpy as np
 import pytest

-from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
+from jmteb.embedders import TextEmbedder


 def pytest_addoption(parser: pytest.Parser):
@@ -21,6 +24,19 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser):
             item.add_marker(skip_slow)


+class DummyTextEmbedder(TextEmbedder):
+    def encode(self, text: str | list[str]) -> np.ndarray:
+        if isinstance(text, str):
+            batch_size = 1
+        else:
+            batch_size = len(text)
+
+        return np.random.random((batch_size, self.get_output_dim()))
+
+    def get_output_dim(self) -> int:
+        return 32
+
+
 @pytest.fixture(scope="module")
-def embedder(model_name_or_path: str = "prajjwal1/bert-tiny"):
-    return SentenceBertEmbedder(model_name_or_path=model_name_or_path)
+def embedder():
+    return DummyTextEmbedder()
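With this change the module-scoped embedder fixture no longer loads prajjwal1/bert-tiny; it returns random vectors of a fixed dimensionality, so tests that only check shapes run fast and offline. An illustrative test built on the fixture (the test function itself is hypothetical, not part of the commit):

def test_dummy_embedder_shapes(embedder):
    # a single string encodes to a (1, dim) batch, a list of n strings to (n, dim)
    single = embedder.encode("some text")
    batch = embedder.encode(["a", "b", "c"])
    assert single.shape == (1, embedder.get_output_dim())
    assert batch.shape == (3, embedder.get_output_dim())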
tests/test_main.py (33 changes: 33 additions & 0 deletions)
@@ -0,0 +1,33 @@
+import subprocess
+import tempfile
+from pathlib import Path
+
+from evaluator.test_sts_evaluator import DummySTSDataset
+
+from jmteb.__main__ import main
+from jmteb.evaluators import STSEvaluator
+
+
+def test_main(embedder):
+    main(
+        text_embedder=embedder,
+        evaluators={"sts": STSEvaluator(val_dataset=DummySTSDataset(), test_dataset=DummySTSDataset())},
+        save_dir=None,
+        overwrite_cache=False,
+    )
+
+
+def test_main_cli():
+    with tempfile.TemporaryDirectory() as f:
+        # fmt: off
+        command = [
+            "python", "-m", "jmteb",
+            "--embedder", "tests.conftest.DummyTextEmbedder",
+            "--save_dir", f,
+            "--eval_include", '["jsts"]',
+        ]
+        # fmt: on
+        result = subprocess.run(command)
+        assert result.returncode == 0
+
+        assert (Path(f) / "summary.json").exists()
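The quoted '["jsts"]' argument works because jsonargparse parses list[str] options from a JSON-style string, which is how --eval_include reaches the filtering code as a real Python list. A small sketch of that parsing in isolation (assuming jsonargparse is installed and Python 3.9+ for the subscripted list[str]):

from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--eval_include", type=list[str], default=None)
args = parser.parse_args(["--eval_include", '["jsts"]'])
print(args.eval_include)  # expected: ['jsts']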
