[dev to main] v1.2.0 #35

Merged: 20 commits from dev into main on Jun 20, 2024
1 change: 1 addition & 0 deletions .github/pull_request_template.md
@@ -20,6 +20,7 @@ Please set the base branch to `dev`.

## Verification
- [ ] Confirmed that the tests pass
- [ ] Confirmed that the merge target is the dev branch
- [ ] ...

<!--
5 changes: 2 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding
name = "JMTEB"
packages = [{from = "src", include = "jmteb"}]
readme = "README.md"
version = "1.1.1"
version = "1.2.0"

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
@@ -30,6 +30,7 @@ smart-open = "^7.0.1"
openai = "^1.16.2"
pytest-mock = "^3.14.0"
tiktoken = "^0.6.0"
numpy = "^1.26"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
21 changes: 16 additions & 5 deletions src/jmteb/embedders/base.py
@@ -14,11 +14,12 @@ class TextEmbedder(ABC):
The base class of text embedder.
"""

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
"""Convert a text string or a list of texts to embedding.

Args:
text (str | list[str]): text string, or a list of texts.
prefix (str, optional): the prefix to use for encoding. Defaults to None.
"""
raise NotImplementedError

@@ -31,14 +32,20 @@ def get_output_dim(self) -> int:
raise NotImplementedError

def _batch_encode_and_save_on_disk(
self, text_list: list[str], save_path: str | PathLike[str], batch_size: int = 64, dtype: str = "float32"
self,
text_list: list[str],
save_path: str | PathLike[str],
prefix: str | None = None,
batch_size: int = 64,
dtype: str = "float32",
) -> np.memmap:
"""
Encode a list of texts and save the embeddings on disk using memmap.

Args:
text_list (list[str]): list of texts
save_path (str): path to save the embeddings
prefix (str, optional): the prefix to use for encoding. Defaults to None.
dtype (str, optional): data type. Defaults to "float32".
batch_size (int): batch size. Defaults to 64.
"""
@@ -50,7 +57,7 @@ def _batch_encode_and_save_on_disk(
with tqdm.tqdm(total=num_samples, desc="Encoding") as pbar:
for i in range(0, num_samples, batch_size):
batch = text_list[i : i + batch_size]
batch_embeddings = self.encode(batch)
batch_embeddings = self.encode(batch, prefix=prefix)
batch_embeddings = np.asarray(batch_embeddings, dtype=dtype)
embeddings[i : i + batch_size] = batch_embeddings
pbar.update(len(batch))
@@ -61,6 +68,7 @@ def _batch_encode_and_save_on_disk(
def batch_encode_with_cache(
self,
text_list: list[str],
prefix: str | None = None,
cache_path: str | PathLike[str] | None = None,
overwrite_cache: bool = False,
batch_size: int = 64,
@@ -71,6 +79,7 @@ def batch_encode_with_cache(

Args:
text_list (list[str]): list of texts
prefix (str, optional): the prefix to use for encoding. Defaults to None.
cache_path (str, optional): path to save the embeddings. Defaults to None.
overwrite_cache (bool, optional): whether to overwrite the cache. Defaults to False.
batch_size (int): batch size. Defaults to 64.
@@ -79,12 +88,14 @@ def batch_encode_with_cache(

if cache_path is None:
logger.info("Encoding embeddings")
return self.encode(text_list).astype(dtype)
return self.encode(text_list, prefix=prefix).astype(dtype)

if Path(cache_path).exists() and not overwrite_cache:
logger.info(f"Loading embeddings from {cache_path}")
return np.memmap(cache_path, dtype=dtype, mode="r", shape=(len(text_list), self.get_output_dim()))

logger.info(f"Encoding and saving embeddings to {cache_path}")
embeddings = self._batch_encode_and_save_on_disk(text_list, cache_path, batch_size=batch_size, dtype=dtype)
embeddings = self._batch_encode_and_save_on_disk(
text_list, cache_path, prefix=prefix, batch_size=batch_size, dtype=dtype
)
return embeddings
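
The change above threads an optional `prefix` through every path of `TextEmbedder`: the direct `encode` call, `_batch_encode_and_save_on_disk`, and `batch_encode_with_cache`. Below is a minimal sketch of how a subclass picks the argument up; `ToyEmbedder` is a hypothetical stand-in (not part of this PR or the library), and it assumes `TextEmbedder` needs no constructor arguments beyond what the diff shows.

```python
# Hypothetical subclass, for illustration only: it prepends the prefix itself,
# mirroring what a real embedder is expected to do with the new argument.
import numpy as np

from jmteb.embedders.base import TextEmbedder


class ToyEmbedder(TextEmbedder):
    def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
        texts = [text] if isinstance(text, str) else text
        if prefix:
            texts = [prefix + t for t in texts]
        # two trivial "features" per text, just so the shapes are real
        return np.asarray([[len(t), sum(map(ord, t)) % 97] for t in texts], dtype=float)

    def get_output_dim(self) -> int:
        return 2


embedder = ToyEmbedder()
# Without cache_path the prefix goes straight to encode();
# with cache_path it is forwarded via _batch_encode_and_save_on_disk.
embeddings = embedder.batch_encode_with_cache(["猫", "犬"], prefix="query: ")
print(embeddings.shape)  # (2, 2)
```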
9 changes: 5 additions & 4 deletions src/jmteb/embedders/openai_embedder.py
@@ -60,13 +60,13 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None
else:
self.dim = dim

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {}
# specifying `dimensions` is not allowed for "text-embedding-ada-002"
if isinstance(text, str):
token_ids: list[int] = self.encode_and_truncate_text(text)
token_ids: list[int] = self.encode_and_truncate_text(text, prefix)
else:
token_ids: list[list[int]] = [self.encode_and_truncate_text(t) for t in text]
token_ids: list[list[int]] = [self.encode_and_truncate_text(t, prefix) for t in text]
result = np.asarray(
[
data.embedding
@@ -84,10 +84,11 @@ def encode(self, text: str | list[str]) -> np.ndarray:
def get_output_dim(self) -> int:
return self.dim

def encode_and_truncate_text(self, text: str) -> list[int]:
def encode_and_truncate_text(self, text: str, prefix: str | None = None) -> list[int]:
# Refer to https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
# return a list of token IDs
if not text:
text = " "
logger.warning("Found empty string!")
# Ignore prefix in OpenAIEmbedder
return self.encoding.encode(text)[: self.max_token_length]
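
`OpenAIEmbedder.encode` now accepts `prefix` only to keep the interface uniform; `encode_and_truncate_text` deliberately ignores it and still just tokenizes and clips the input with tiktoken. A small stand-alone sketch of that truncation step follows (no API call; the encoding name and token limit are assumptions for illustration):

```python
# Standalone sketch of the truncation in encode_and_truncate_text; the encoding
# name and the 8191-token limit are illustrative assumptions, not PR values.
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
max_token_length = 8191

text = "日本語テキスト埋め込みベンチマーク。" * 2000
token_ids = encoding.encode(text)[:max_token_length]  # clip before calling the API
print(len(token_ids))  # at most 8191; an empty input string would be replaced by " "
```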
26 changes: 24 additions & 2 deletions src/jmteb/embedders/sbert_embedder.py
@@ -15,20 +15,42 @@ def __init__(
batch_size: int = 32,
device: str | None = None,
normalize_embeddings: bool = False,
max_seq_length: int | None = None,
add_eos: bool = False,
tokenizer_kwargs: dict | None = None,
) -> None:
self.model = SentenceTransformer(model_name_or_path)
self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs)
if max_seq_length:
self.model.max_seq_length = max_seq_length

self.batch_size = batch_size
self.device = device
self.normalize_embeddings = normalize_embeddings
self.max_seq_length = getattr(self.model, "max_seq_length", None)
self.add_eos = add_eos

def encode(self, text: str | list[str]) -> np.ndarray:
def encode(self, text: str | list[str], prefix: str | None = None) -> np.ndarray:
if self.add_eos:
text = self._add_eos_func(text)
return self.model.encode(
text,
prompt=prefix,
convert_to_numpy=True,
batch_size=self.batch_size,
device=self.device,
normalize_embeddings=self.normalize_embeddings,
)

def _add_eos_func(self, text: str | list[str]) -> str | list[str]:
try:
eos_token = getattr(self.model.tokenizer, "eos_token")
except AttributeError:
return text

if isinstance(text, str):
return text + eos_token
elif isinstance(text, list):
return [t + eos_token for t in text]

def get_output_dim(self) -> int:
return self.model.get_sentence_embedding_dimension()
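
For the Sentence-Transformers backend the new `prefix` is passed to `SentenceTransformer.encode` as `prompt`, and the constructor gains `max_seq_length`, `add_eos`, and `tokenizer_kwargs`. A hedged usage sketch follows; it assumes the class in `sbert_embedder.py` is named `SentenceBertEmbedder` and uses an arbitrary example checkpoint and prefix string.

```python
# Usage sketch only; the model name, prefix string, and class name are assumptions.
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

embedder = SentenceBertEmbedder(
    model_name_or_path="intfloat/multilingual-e5-small",  # example checkpoint
    batch_size=32,
    normalize_embeddings=True,
    max_seq_length=512,     # new: overrides the model's default sequence length
    add_eos=False,          # new: append the tokenizer's EOS token when True
    tokenizer_kwargs=None,  # new: forwarded to the underlying tokenizer
)

# `prefix` is handed to SentenceTransformer.encode as `prompt`, so it is
# prepended to each input before tokenization.
embeddings = embedder.encode(["これはテストです。"], prefix="query: ")
print(embeddings.shape)  # (1, embedder.get_output_dim())
```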
6 changes: 6 additions & 0 deletions src/jmteb/evaluators/classification/evaluator.py
@@ -27,6 +27,7 @@ class ClassificationEvaluator(EmbeddingEvaluator):
and delimited by comma, e.g., `macro, micro`.
The first one is specified as the main index.
classifiers (dict[str, Classifier]): classifiers to be evaluated.
prefix (str | None): prefix for sentences. Defaults to None.
"""

def __init__(
@@ -36,6 +37,7 @@ def __init__(
test_dataset: ClassificationDataset,
average: str = "macro",
classifiers: dict[str, Classifier] | None = None,
prefix: str | None = None,
) -> None:
self.train_dataset = train_dataset
self.val_dataset = val_dataset
@@ -49,6 +51,7 @@ def __init__(
for average_name in average
if average_name.strip().lower() in ("micro", "macro", "samples", "weighted", "binary")
] or ["macro"]
self.prefix = prefix
self.main_metric = f"{self.average[0]}_f1"

def __call__(
@@ -60,13 +63,15 @@ def __call__(
logger.info("Encoding training and validation sentences...")
X_train = model.batch_encode_with_cache(
[item.text for item in self.train_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "train_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
y_train = [item.label for item in self.train_dataset]

X_val = model.batch_encode_with_cache(
[item.text for item in self.val_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -79,6 +84,7 @@ def __call__(
else:
X_test = model.batch_encode_with_cache(
[item.text for item in self.test_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
12 changes: 10 additions & 2 deletions src/jmteb/evaluators/clustering/evaluator.py
@@ -30,9 +30,13 @@ def __init__(
self,
val_dataset: ClusteringDataset,
test_dataset: ClusteringDataset,
prefix: str | None = None,
random_seed: int | None = None,
) -> None:
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.prefix = prefix
self.random_seed = random_seed
self.main_metric = "v_measure_score"

def __call__(
@@ -44,6 +48,7 @@ def __call__(
logger.info("Converting validation data to embeddings...")
val_embeddings = model.batch_encode_with_cache(
[item.text for item in self.val_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "val_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -56,16 +61,19 @@ def __call__(
else:
test_embeddings = model.batch_encode_with_cache(
[item.text for item in self.test_dataset],
prefix=self.prefix,
cache_path=Path(cache_dir) / "test_embeddings.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
test_labels = [item.label for item in self.test_dataset]

n_clusters = len(set(test_labels))
model_constructors: dict[str, Callable[[], ClusterMixin]] = {
"MiniBatchKMeans": lambda: MiniBatchKMeans(n_clusters=n_clusters, n_init="auto"),
"MiniBatchKMeans": lambda: MiniBatchKMeans(
n_clusters=n_clusters, n_init="auto", random_state=self.random_seed
),
"AgglomerativeClustering": lambda: AgglomerativeClustering(n_clusters=n_clusters),
"BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters),
"BisectingKMeans": lambda: BisectingKMeans(n_clusters=n_clusters, random_state=self.random_seed),
"Birch": lambda: Birch(n_clusters=n_clusters),
}

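
The clustering evaluator now takes a `random_seed` and forwards it as `random_state` to `MiniBatchKMeans` and `BisectingKMeans`, which makes the reported scores reproducible across runs. A self-contained sketch of the effect on synthetic data (not JMTEB code):

```python
# With a fixed random_state the mini-batch k-means assignment is deterministic,
# so repeated evaluations yield identical cluster labels on the same embeddings.
import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 8))  # synthetic stand-in for text embeddings

labels_a = MiniBatchKMeans(n_clusters=4, n_init="auto", random_state=42).fit_predict(embeddings)
labels_b = MiniBatchKMeans(n_clusters=4, n_init="auto", random_state=42).fit_predict(embeddings)
print(np.array_equal(labels_a, labels_b))  # True
```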
10 changes: 9 additions & 1 deletion src/jmteb/evaluators/pair_classification/evaluator.py
@@ -20,15 +20,21 @@ class PairClassificationEvaluator(EmbeddingEvaluator):
Args:
val_dataset (PairClassificationDataset): validation dataset
test_dataset (PairClassificationDataset): test dataset
sentence1_prefix (str | None): prefix for sentence1. Defaults to None.
sentence2_prefix (str | None): prefix for sentence2. Defaults to None.
"""

def __init__(
self,
val_dataset: PairClassificationDataset,
test_dataset: PairClassificationDataset,
sentence1_prefix: str | None = None,
sentence2_prefix: str | None = None,
) -> None:
self.test_dataset = test_dataset
self.val_dataset = val_dataset
self.sentence1_prefix = sentence1_prefix
self.sentence2_prefix = sentence2_prefix
self.metrics = [ThresholdAccuracyMetric(), ThresholdF1Metric()]
self.main_metric = "binary_f1"

@@ -101,8 +107,8 @@ def __call__(
},
)

@staticmethod
def _convert_to_embeddings(
self,
model: TextEmbedder,
dataset: PairClassificationDataset,
split: str = "test",
@@ -111,11 +117,13 @@ def _convert_to_embeddings(
) -> tuple[np.ndarray, np.ndarray, list[float]]:
embeddings1 = model.batch_encode_with_cache(
[item.sentence1 for item in dataset],
prefix=self.sentence1_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings1.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
embeddings2 = model.batch_encode_with_cache(
[item.sentence2 for item in dataset],
prefix=self.sentence2_prefix,
cache_path=Path(cache_dir) / f"{split}_embeddings2.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
9 changes: 9 additions & 0 deletions src/jmteb/evaluators/reranking/evaluator.py
@@ -28,6 +28,8 @@ class RerankingEvaluator(EmbeddingEvaluator):
test_query_dataset (RerankingQueryDataset): test query dataset used for computing the scores
doc_dataset (RerankingDocDataset): document dataset
ndcg_at_k (list[int] | None): top k documents to consider in NDCG (Normalized Discounted Cumulative Gain).
query_prefix (str | None): prefix for queries. Defaults to None.
doc_prefix (str | None): prefix for documents. Defaults to None.
"""

def __init__(
@@ -36,12 +38,16 @@ def __init__(
test_query_dataset: RerankingQueryDataset,
doc_dataset: RerankingDocDataset,
ndcg_at_k: list[int] | None = None,
query_prefix: str | None = None,
doc_prefix: str | None = None,
) -> None:
self.test_query_dataset = test_query_dataset
self.val_query_dataset = val_query_dataset
self.doc_dataset = doc_dataset
self.ndcg_at_k = ndcg_at_k or [10, 20, 40]
self.main_metric = f"ndcg@{self.ndcg_at_k[0]}"
self.query_prefix = query_prefix
self.doc_prefix = doc_prefix

def __call__(
self,
@@ -54,6 +60,7 @@ def __call__(

val_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.val_query_dataset],
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "val_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
@@ -62,11 +69,13 @@ def __call__(
else:
test_query_embeddings = model.batch_encode_with_cache(
text_list=[item.query for item in self.test_query_dataset],
prefix=self.query_prefix,
cache_path=Path(cache_dir) / "test_query.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
doc_embeddings = model.batch_encode_with_cache(
text_list=[item.text for item in self.doc_dataset],
prefix=self.doc_prefix,
cache_path=Path(cache_dir) / "corpus.bin" if cache_dir is not None else None,
overwrite_cache=overwrite_cache,
)
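
The reranking evaluator now accepts separate `query_prefix` and `doc_prefix`, which matters for asymmetric models that expect different markers on queries and passages. The toy sketch below (not JMTEB internals) mirrors the evaluator's flow: encode queries and documents with their own prefixes, score by similarity, then compute NDCG@k. The `query: `/`passage: ` strings and the toy encoder are illustrative assumptions.

```python
# Toy end-to-end flow: prefix-aware encoding, similarity scoring, NDCG@k.
import numpy as np
from sklearn.metrics import ndcg_score

QUERY_PREFIX = "query: "    # illustrative, e.g. for E5-style models
DOC_PREFIX = "passage: "


def toy_encode(texts: list[str], prefix: str | None = None) -> np.ndarray:
    """Stand-in embedder: a few character-count features, just to keep this runnable."""
    full = [(prefix or "") + t for t in texts]
    return np.asarray([[len(t), t.count("猫"), t.count("犬")] for t in full], dtype=float)


queries = ["猫について教えて"]
docs = ["猫はかわいい動物です", "犬は忠実な動物です", "今日は晴れです"]
relevance = np.asarray([[2, 0, 0]])  # gold relevance of each doc for the single query

query_embeddings = toy_encode(queries, prefix=QUERY_PREFIX)
doc_embeddings = toy_encode(docs, prefix=DOC_PREFIX)
scores = query_embeddings @ doc_embeddings.T  # rank documents by similarity

print(ndcg_score(relevance, scores, k=10))  # ndcg@10 of this toy ranking
```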